From 91370aa3e5d4af889afdaf906108c76907234ccf Mon Sep 17 00:00:00 2001
From: Amritpal Singh
Date: Mon, 19 Aug 2024 16:10:31 -0700
Subject: [PATCH 01/87] Index, Query by context + SA integration + Placeholder
 + Context Switch API Placeholder

---
 .gitignore                                    |  3 +-
 .vscode/launch.json                           | 35 ++++---
 graphrag/common/blob_storage_client.py        | 58 ++++++++++++
 graphrag/common/graph_db_client.py            |  1 +
 graphrag/common/kusto_db_client.py            |  1 +
 graphrag/config/__init__.py                   |  5 +
 graphrag/config/create_graphrag_config.py     | 12 +++
 graphrag/config/enums.py                      | 12 +++
 .../query_context_config_input.py             |  7 ++
 graphrag/config/models/__init__.py            |  2 +
 graphrag/config/models/graph_rag_config.py    |  6 ++
 .../config/models/query_context_config.py     | 16 ++++
 graphrag/index/__main__.py                    | 14 +++
 graphrag/index/cli.py                         | 19 ++++
 .../index/context_switch/contextSwitcher.py   | 22 +++++
 graphrag/index/input/text.py                  |  4 +-
 graphrag/index/verbs/text/embed/text_embed.py |  2 +-
 graphrag/query/__main__.py                    | 11 +++
 graphrag/query/cli.py                         | 93 ++++++++++++++-----
 graphrag/query/indexer_adapters.py            |  2 +-
 graphrag/vector_stores/kustodb.py             | 41 ++++++++
 graphrag/vector_stores/typing.py              |  6 +-
 22 files changed, 335 insertions(+), 37 deletions(-)
 create mode 100644 graphrag/common/blob_storage_client.py
 create mode 100644 graphrag/common/graph_db_client.py
 create mode 100644 graphrag/common/kusto_db_client.py
 create mode 100644 graphrag/config/input_models/query_context_config_input.py
 create mode 100644 graphrag/config/models/query_context_config.py
 create mode 100644 graphrag/index/context_switch/contextSwitcher.py
 create mode 100644 graphrag/vector_stores/kustodb.py

diff --git a/.gitignore b/.gitignore
index bff8e24810..75409a00d7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -65,4 +65,5 @@ __blobstorage__/
 ragtest/
 .ragtest/
 .pipelines
-.pipeline
\ No newline at end of file
+.pipeline
+ragtest*/
\ No newline at end of file
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 909771b809..72f031baf1 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -1,12 +1,25 @@
 {
-    "version": "0.2.0",
-    "configurations": [
-        {
-            "name": "Attach to Node Functions",
-            "type": "node",
-            "request": "attach",
-            "port": 9229,
-            "preLaunchTask": "func: host start"
-        }
-    ]
-}
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Indexer",
+            "type": "debugpy",
+            "python": "E:\\graphrag\\.venv\\Scripts\\python.exe",
+            "request": "launch",
+            "cwd": "${workspaceFolder}",
+            "module": "poetry",
+            "args": ["poe", "index", "--root", ".\\ragtest6"],
+            "stopOnEntry": false
+        },
+        {
+            "name": "Query",
+            "type": "debugpy",
+            "python": "E:\\graphrag\\.venv\\Scripts\\python.exe",
+            "request": "launch",
+            "cwd": "${workspaceFolder}",
+            "module": "poetry",
+            "args": ["poe", "query", "--root", ".\\ragtest6", "--method", "local", "Who provided access to Amritpal at first place to Unified Feedback KV Certificate?"],
+            "stopOnEntry": false
+        }
+    ]
+}
\ No newline at end of file
diff --git a/graphrag/common/blob_storage_client.py b/graphrag/common/blob_storage_client.py
new file mode 100644
index 0000000000..78dc809579
--- /dev/null
+++ b/graphrag/common/blob_storage_client.py
@@ -0,0 +1,58 @@
+from azure.identity import DefaultAzureCredential
+from azure.storage.blob import BlobServiceClient
+
+
+class BlobStorageClient:
+    """The Blob-Storage implementation."""
+
+    _connection_string: str | None
+    _container_name: str
+    _path_prefix: str
+    _encoding: str
+    _storage_account_blob_url: str | None
+
+    def __init__(
+        self,
+        connection_string: str | None,
+        container_name: str,
+        encoding: str | None = None,
+        path_prefix: str | None = None,
+        storage_account_blob_url: str | None = None,
+    ):
+        """Create a new BlobStorage instance."""
+        if connection_string:
+            self._blob_service_client = BlobServiceClient.from_connection_string(
+                connection_string
+            )
+        else:
+            if storage_account_blob_url is None:
+                msg = "Either connection_string or storage_account_blob_url must be provided."
+                raise ValueError(msg)
+
+            self._blob_service_client = BlobServiceClient(
+                account_url=storage_account_blob_url,
+                credential=DefaultAzureCredential(),
+            )
+        self._encoding = encoding or "utf-8"
+        self._container_name = container_name
+        self._connection_string = connection_string
+        self._path_prefix = path_prefix or ""
+        self._storage_account_blob_url = storage_account_blob_url
+        self._storage_account_name = (
+            storage_account_blob_url.split("//")[1].split(".")[0]
+            if storage_account_blob_url
+            else None
+        )
+        #log.info(
+        #    "creating blob storage at container=%s, path=%s",
+        #    self._container_name,
+        #    self._path_prefix,
+        #)
+
+    def get_blob_service_client(self):
+        """Get the BlobServiceClient instance."""
+        return self._blob_service_client
+
+    def get_container_client(self):
+        """Get the container client instance."""
+        return self._blob_service_client.get_container_client(self._container_name)
diff --git a/graphrag/common/graph_db_client.py b/graphrag/common/graph_db_client.py
new file mode 100644
index 0000000000..35b70dd385
--- /dev/null
+++ b/graphrag/common/graph_db_client.py
@@ -0,0 +1 @@
+# create Gremlin and cosmos db clients by reading settings from settings.yaml
\ No newline at end of file
diff --git a/graphrag/common/kusto_db_client.py b/graphrag/common/kusto_db_client.py
new file mode 100644
index 0000000000..413a47341e
--- /dev/null
+++ b/graphrag/common/kusto_db_client.py
@@ -0,0 +1 @@
+# create Kusto db client by reading settings from settings.yaml
\ No newline at end of file
diff --git a/graphrag/config/__init__.py b/graphrag/config/__init__.py
index 118018a98f..5870c4ae71 100644
--- a/graphrag/config/__init__.py
+++ b/graphrag/config/__init__.py
@@ -8,6 +8,7 @@
 )
 from .enums import (
     CacheType,
+    ContextSwitchType,
     InputFileType,
     InputType,
     LLMType,
@@ -57,6 +58,7 @@
     LLMParameters,
     LocalSearchConfig,
     ParallelizationParameters,
+    QueryContextConfig,
     ReportingConfig,
     SnapshotsConfig,
     StorageConfig,
@@ -71,6 +73,7 @@
     "AzureApiBaseMissingError",
     "AzureDeploymentNameMissingError",
     "CacheConfig",
    "CacheConfigInput",
    "CacheType",
    "ChunkingConfig",
+    "ContextSwitchType",
@@ -102,6 +105,8 @@
     "LocalSearchConfigInput",
     "ParallelizationParameters",
     "ParallelizationParametersInput",
+    "QueryContextConfig",
+    "QueryContextConfigInput",
     "ReportingConfig",
     "ReportingConfigInput",
     "ReportingType",
diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py
index 3504507be2..5686832736 100644
--- a/graphrag/config/create_graphrag_config.py
+++ b/graphrag/config/create_graphrag_config.py
@@ -47,6 +47,7 @@
     LLMParameters,
     LocalSearchConfig,
     ParallelizationParameters,
+    QueryContextConfig,
     ReportingConfig,
     SnapshotsConfig,
     StorageConfig,
@@ -539,6 +540,14 @@ def hydrate_parallelization_params(
             or defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS,
             concurrency=reader.int("concurrency") or defs.GLOBAL_SEARCH_CONCURRENCY,
         )
+
+    with (
+        reader.use(values.get("query_context")),
+        reader.envvar_prefix(Section.query_context),
+    ):
+        query_context_model = QueryContextConfig(
+            files=reader.list("files") or [],
+        )
 
     encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
     skip_workflows = reader.list("skip_workflows") or []
@@ -566,6 +575,7 @@ def hydrate_parallelization_params(
         skip_workflows=skip_workflows,
         local_search=local_search_model,
         global_search=global_search_model,
+        query_context=query_context_model,
     )
 
 
@@ -608,6 +618,7 @@ class Fragment(str, Enum):
     thread_stagger = "THREAD_STAGGER"
     tpm = "TOKENS_PER_MINUTE"
     type = "TYPE"
+    output = "OUTPUT"
 
 
 class Section(str, Enum):
@@ -631,6 +642,7 @@ class Section(str, Enum):
     umap = "UMAP"
     local_search = "LOCAL_SEARCH"
     global_search = "GLOBAL_SEARCH"
+    query_context = "QUERY_CONTEXT"
 
 
 def _is_azure(llm_type: LLMType | None) -> bool:
diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py
index 8741cf74ae..4745acc5f5 100644
--- a/graphrag/config/enums.py
+++ b/graphrag/config/enums.py
@@ -113,3 +113,15 @@ class LLMType(str, Enum):
     def __repr__(self):
         """Get a string representation."""
         return f'"{self.value}"'
+
+
+class ContextSwitchType(str, Enum):
+    """Context switcher type."""
+
+    Activate = "activate"
+    Deactivate = "deactivate"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
diff --git a/graphrag/config/input_models/query_context_config_input.py b/graphrag/config/input_models/query_context_config_input.py
new file mode 100644
index 0000000000..c8f3d2e783
--- /dev/null
+++ b/graphrag/config/input_models/query_context_config_input.py
@@ -0,0 +1,7 @@
+from typing_extensions import NotRequired, TypedDict
+
+
+class QueryContextConfigInput(TypedDict):
+    """The default configuration section for the query context."""
+
+    files: NotRequired[list[str]]
+    """The list of files to run the query on."""
diff --git a/graphrag/config/models/__init__.py b/graphrag/config/models/__init__.py
index 43c4cde506..f2c5185c66 100644
--- a/graphrag/config/models/__init__.py
+++ b/graphrag/config/models/__init__.py
@@ -17,6 +17,7 @@
 from .llm_parameters import LLMParameters
 from .local_search_config import LocalSearchConfig
 from .parallelization_parameters import ParallelizationParameters
+from .query_context_config import QueryContextConfig
 from .reporting_config import ReportingConfig
 from .snapshots_config import SnapshotsConfig
 from .storage_config import StorageConfig
@@ -39,6 +40,7 @@
     "LLMParameters",
     "LocalSearchConfig",
     "ParallelizationParameters",
+    "QueryContextConfig",
     "ReportingConfig",
     "SnapshotsConfig",
     "StorageConfig",
diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py
index 197d85e3e9..ab3cbd9fdd 100644
--- a/graphrag/config/models/graph_rag_config.py
+++ b/graphrag/config/models/graph_rag_config.py
@@ -19,6 +19,7 @@
 from .input_config import InputConfig
 from .llm_config import LLMConfig
 from .local_search_config import LocalSearchConfig
+from .query_context_config import QueryContextConfig
 from .reporting_config import ReportingConfig
 from .snapshots_config import SnapshotsConfig
 from .storage_config import StorageConfig
@@ -144,3 +145,8 @@ def __str__(self):
         description="The workflows to skip, usually for testing reasons.", default=[]
     )
     """The workflows to skip, usually for testing reasons."""
+
+    query_context: QueryContextConfig = Field(
+        description="The query context to use.", default_factory=QueryContextConfig
+    )
+    """The query context to use."""
\ No newline at end of file
diff --git a/graphrag/config/models/query_context_config.py b/graphrag/config/models/query_context_config.py
new file mode 100644
index 0000000000..15626efba9
--- /dev/null
+++ b/graphrag/config/models/query_context_config.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+
+class QueryContextConfig(BaseModel):
+    """The default configuration section for the query context."""
+
+    files: list[str] = Field(
+        description="The list of the files on which the query should be run.",
+        default=[],
+    )
\ No newline at end of file
diff --git a/graphrag/index/__main__.py b/graphrag/index/__main__.py
index 578ffc9c33..7eb68ffb2a 100644
--- a/graphrag/index/__main__.py
+++ b/graphrag/index/__main__.py
@@ -68,6 +68,18 @@
         help="Overlay default configuration values on a provided configuration file (--config).",
         action="store_true",
     )
+    parser.add_argument(
+        "--contextId",
+        help="Context id to activate or deactivate.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--contextOperation",
+        help="Context operation activate or deactivate.",
+        # Only required if contextId is provided
+        action="store_true",
+    )
+
     args = parser.parse_args()
 
     if args.overlay_defaults and not args.config:
@@ -86,4 +98,6 @@
         init=args.init or False,
         overlay_defaults=args.overlay_defaults or False,
         cli=True,
+        contextId=args.contextId,
+        contextOperation=args.contextOperation,
     )
diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py
index 68dbcf785f..506656111c 100644
--- a/graphrag/index/cli.py
+++ b/graphrag/index/cli.py
@@ -10,11 +10,13 @@
 import sys
 import time
 import warnings
+from enum import Enum
 from pathlib import Path
 
 from graphrag.config import (
     create_graphrag_config,
 )
+from graphrag.config.enums import ContextSwitchType
 from graphrag.index import PipelineConfig, create_pipeline_config
 from graphrag.index.cache import NoopPipelineCache
 from graphrag.index.progress import (
@@ -71,6 +73,8 @@ def redact_dict(input: dict) -> dict:
 def index_cli(
     root: str,
     init: bool,
+    contextOperation: str | None,
+    contextId: str | None,
     verbose: bool,
     resume: str | None,
     memprofile: bool,
@@ -97,6 +101,9 @@ def index_cli(
     pipeline_config: str | PipelineConfig = config or _create_default_config(
         root, None, verbose, dryrun or False, progress_reporter
     )
+    if contextId:
+        _switch_context(pipeline_config, contextOperation, contextId, progress_reporter)
+        sys.exit(0)
     cache = NoopPipelineCache() if nocache else None
     pipeline_emit = emit.split(",") if emit else None
     encountered_errors = False
@@ -170,6 +177,18 @@ async def execute():
     if cli:
         sys.exit(1 if encountered_errors else 0)
 
+def _switch_context(config: PipelineConfig | str, context_operation: str | None, context_id: str, reporter: ProgressReporter) -> None:
+    """Switch the context to the given context."""
+    reporter.info(f"Switching context to {context_id} using operation {context_operation}")
+    from graphrag.index.context_switch.contextSwitcher import ContextSwitcher
+    context_switcher = ContextSwitcher()
+    if context_operation == ContextSwitchType.Activate:
+        context_switcher.activate(config, context_id, reporter)
+    elif context_operation == ContextSwitchType.Deactivate:
+        context_switcher.deactivate(config, context_id, reporter)
+    else:
+        msg = f"Invalid context operation {context_operation}"
+        raise ValueError(msg)
 
 def _initialize_project_at(path: str, reporter: ProgressReporter) -> None:
     """Initialize the project at the given path."""
diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py
new file mode 100644
index 0000000000..631e1f97f3
--- /dev/null
+++ b/graphrag/index/context_switch/contextSwitcher.py
@@ -0,0 +1,22 @@
+from graphrag.index.progress import ProgressReporter
+from graphrag.index import PipelineConfig
+
+class ContextSwitcher:
+    """ContextSwitcher class definition."""
+
+    def __init__(self):
+        #initialize Gremlin and Cosmos Db client here.
+        pass
+    def activate(self, config: PipelineConfig | str, contextId: str | None, reporter: ProgressReporter):
+        """Activate the context."""
+        #1. read the context id to fileId mapping.
+        #2. read the file from storage.
+        #3. LanceDB: use cosmos db client to load data into Cosmos DB.
+        #4. KustoDB: use Kusto client to load embedding data into Kusto.
+
+        return 0
+
+    def deactivate(self, config: PipelineConfig | str, contextId: str | None, reporter: ProgressReporter):
+        """Deactivate the context."""
+        #1. Delete all the data for a given context id.
+        return 0
\ No newline at end of file
diff --git a/graphrag/index/input/text.py b/graphrag/index/input/text.py
index 2a676c0902..b2cb97225b 100644
--- a/graphrag/index/input/text.py
+++ b/graphrag/index/input/text.py
@@ -30,7 +30,7 @@ async def load(
     """Load text inputs from a directory."""
 
     async def load_file(
-        path: str, group: dict | None = None, _encoding: str = "utf-8"
+        path: str, group: dict | None = None, _encoding: str = "utf-8" #what is group here, can be used as context?
     ) -> dict[str, Any]:
         if group is None:
             group = {}
@@ -45,8 +45,10 @@ async def load_file(
             re.compile(config.file_pattern),
             progress=progress,
             file_filter=config.file_filter,
+            base_dir=config.base_dir,
         )
     )
+    #change here to run indexer on each file one by one.
     if len(files) == 0:
         msg = f"No text files found in {config.base_dir}"
         raise ValueError(msg)
diff --git a/graphrag/index/verbs/text/embed/text_embed.py b/graphrag/index/verbs/text/embed/text_embed.py
index 76ac97d76f..30f572d05b 100644
--- a/graphrag/index/verbs/text/embed/text_embed.py
+++ b/graphrag/index/verbs/text/embed/text_embed.py
@@ -75,7 +75,7 @@ async def text_embed(
         max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai
         organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai
         vector_store: # The optional configuration for the vector store
-            type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb
+            type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb, kusto
         <...>
     ```
     """
diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py
index edf678fa44..d572194c80 100644
--- a/graphrag/query/__main__.py
+++ b/graphrag/query/__main__.py
@@ -69,6 +69,13 @@ def __str__(self):
         default="Multiple Paragraphs",
     )
 
+    parser.add_argument(
+        "--context_id",
+        help="Guid describing context in which the search should be performed",
+        type=str,
+        default="00000000-0000-0000-0000-000000000000",
+    )
+
     parser.add_argument(
         "query",
         nargs=1,
@@ -76,6 +83,8 @@ def __str__(self):
         type=str,
     )
 
     args = parser.parse_args()
 
     match args.method:
@@ -86,6 +95,7 @@ def __str__(self):
             args.root,
             args.community_level,
             args.response_type,
+            args.context_id,
             args.query[0],
         )
         case SearchType.GLOBAL:
@@ -95,6 +105,7 @@ def __str__(self):
             args.root,
             args.community_level,
             args.response_type,
+            args.context_id,
             args.query[0],
         )
         case _:
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index 81efbb550b..b41c844ddd 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -6,6 +6,9 @@
 import os
 from pathlib import Path
 from typing import cast
+from io import BytesIO
+from graphrag.config.enums import StorageType
+from azure.core.exceptions import ResourceNotFoundError
 
 import pandas as pd
 
@@ -20,6 +23,7 @@
 )
 from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
 from graphrag.vector_stores.lancedb import LanceDBVectorStore
+from graphrag.common.blob_storage_client import BlobStorageClient
 
 from .factories import get_global_search_engine, get_local_search_engine
 from .indexer_adapters import (
@@ -32,7 +36,6 @@
 
 reporter = PrintProgressReporter("")
 
-
 def __get_embedding_description_store(
     entities: list[Entity],
     vector_store_type: str = VectorStoreType.LanceDB,
@@ -86,6 +89,7 @@ def run_global_search(
     root_dir: str | None,
     community_level: int,
     response_type: str,
+    context_id: str,
     query: str,
 ):
     """Run a global search with the given query."""
@@ -127,38 +131,52 @@ def run_local_search(
     root_dir: str | None,
     community_level: int,
     response_type: str,
+    context_id: str,
     query: str,
 ):
     """Run a local search with the given query."""
     data_dir, root_dir, config = _configure_paths_and_settings(
         data_dir, root_dir, config_dir
     )
-    data_path = Path(data_dir)
-
-    final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
-    final_community_reports = pd.read_parquet(
-        data_path / "create_final_community_reports.parquet"
-    )
-    final_text_units = pd.read_parquet(data_path / "create_final_text_units.parquet")
-    final_relationships = pd.read_parquet(
-        data_path / "create_final_relationships.parquet"
-    )
-    final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
-    final_covariates_path = data_path / "create_final_covariates.parquet"
-    final_covariates = (
-        pd.read_parquet(final_covariates_path)
-        if final_covariates_path.exists()
-        else None
-    )
+
+    data_paths = []
+    data_paths = get_files_by_context(config, context_id)
+    #data_paths = [Path("E:\\graphrag\\ragtest6\\output\\AtoG\\artifacts")]
+    #data_paths = [Path("E:\\graphrag\\auditlogstest\\output\\securityPlatformPPE\\artifacts"),Path("E:\\graphrag\\auditlogstest\\output\\UnifiedFeedbackPPE\\artifacts")]
+    #data_paths.append(Path(data_dir))
+    final_nodes = pd.DataFrame()
+    final_community_reports = pd.DataFrame()
+    final_text_units = pd.DataFrame()
+    final_relationships = pd.DataFrame()
+    final_entities = pd.DataFrame()
+    final_covariates = pd.DataFrame()
+    for data_path in data_paths:
+        #check from the config for the output storage type and then read the data from the storage.
+
+        #GraphDB: we may need to make change below to read nodes data from Graph DB
+        final_nodes = pd.concat([final_nodes, read_paraquet_file(config, data_path + "/create_final_nodes.parquet", config.storage.type)])
+
+        final_community_reports = pd.concat([final_community_reports, read_paraquet_file(config, data_path + "/create_final_community_reports.parquet", config.storage.type)])
+
+        final_text_units = pd.concat([final_text_units, read_paraquet_file(config, data_path + "/create_final_text_units.parquet", config.storage.type)])
+
+        final_relationships = pd.concat([final_relationships, read_paraquet_file(config, data_path + "/create_final_relationships.parquet", config.storage.type)])
+
+        final_entities = pd.concat([final_entities, read_paraquet_file(config, data_path + "/create_final_entities.parquet", config.storage.type)])
+
+        final_covariates = pd.concat([final_covariates, (
+            read_paraquet_file(config, data_path + "/create_final_covariates.parquet", config.storage.type)
+        )])
 
     vector_store_args = (
         config.embeddings.vector_store if config.embeddings.vector_store else {}
    )
 
     reporter.info(f"Vector Store Args: {vector_store_args}")
-    vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
+    vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) # verify kusto vector store here.
 
-    entities = read_indexer_entities(final_nodes, final_entities, community_level)
+    entities = read_indexer_entities(final_nodes, final_entities, community_level) # Change it to read file specific indexer files.
     description_embedding_store = __get_embedding_description_store(
         entities=entities,
         vector_store_type=vector_store_type,
@@ -166,7 +184,7 @@ def run_local_search(
     )
     covariates = (
         read_indexer_covariates(final_covariates)
-        if final_covariates is not None
+        if not final_covariates.empty
         else []
     )
 
@@ -187,8 +205,41 @@ def run_local_search(
     reporter.success(f"Local Search Response: {result.response}")
     return result.response
 
+def get_files_by_context(config: GraphRagConfig, context_id: str):
+    data_paths = config.query_context.files
+    return data_paths
+
+def blob_exists(container_client, blob_name):
+    blob_client = container_client.get_blob_client(blob_name)
+    try:
+        # Attempt to get the blob properties
+        blob_client.get_blob_properties()
+        return True
+    except ResourceNotFoundError:
+        # Blob does not exist
+        return False
+
+
+def read_paraquet_file(config: GraphRagConfig, path: str, storageType: StorageType):
+    #create different enum for parquet storage type
+    if storageType == StorageType.blob:
+        container_name = config.input.container_name or ""
+        blobStorageClient = BlobStorageClient(connection_string=config.input.connection_string, container_name=container_name, encoding="utf-8")
+        container_client = blobStorageClient.get_container_client()
+        if blob_exists(container_client, path):
+            blob_data = container_client.download_blob(blob=path)
+            bytes_io = BytesIO(blob_data.readall())
+            return pd.read_parquet(bytes_io, engine="pyarrow")
+        else:
+            return pd.DataFrame() # return an empty data frame if the file doesn't exist (e.g., covariates)
+    else:
+        file_path = Path(path)
+        if not file_path.exists():
+            raise ValueError(f"Data path {file_path} does not exist.")
+        return pd.read_parquet(path)
 
 def _configure_paths_and_settings(
     data_dir: str | None,
     root_dir: str | None,
     config_dir: str | None,
diff --git a/graphrag/query/indexer_adapters.py b/graphrag/query/indexer_adapters.py
index 4ce90d9215..132db92c1f 100644
--- a/graphrag/query/indexer_adapters.py
+++ b/graphrag/query/indexer_adapters.py
@@ -47,7 +47,7 @@ def read_indexer_covariates(final_covariates: pd.DataFrame) -> list[Covariate]:
         text_unit_ids_col=None,
     )
 
-
+# GraphDB: read relationships from the graph db.
 def read_indexer_relationships(final_relationships: pd.DataFrame) -> list[Relationship]:
     """Read in the Relationships from the raw indexing outputs."""
     return read_relationships(
diff --git a/graphrag/vector_stores/kustodb.py b/graphrag/vector_stores/kustodb.py
new file mode 100644
index 0000000000..b009e2cbd6
--- /dev/null
+++ b/graphrag/vector_stores/kustodb.py
@@ -0,0 +1,41 @@
+# write kusto db here.
+from typing import Any
+
+from graphrag.model.types import TextEmbedder
+
+from .base import (
+    BaseVectorStore,
+    VectorStoreDocument,
+    VectorStoreSearchResult,
+)
+
+
+class KustoDBVectorStore(BaseVectorStore):
+    """Kusto vector store."""
+
+    def __init__(self, **kwargs):
+        """Initialize the Kusto vector store."""
+        pass
+
+    def connect(self, **kwargs: Any) -> Any:
+        """Connect to the vector storage."""
+        pass
+
+    def load_documents(
+        self, documents: list[VectorStoreDocument], overwrite: bool = True
+    ) -> None:
+        """Load documents into vector storage."""
+        pass
+
+    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
+        """Build a query filter to filter documents by id."""
+        pass
+
+    def similarity_search_by_vector(
+        self, query_embedding: list[float], k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform a vector-based similarity search."""
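+        # Sketch of one possible KQL shape for this search (an assumption for
+        # discussion, not a final design; the table and column names below
+        # are illustrative only):
+        #   <embeddings_table>
+        #   | extend similarity = series_cosine_similarity(
+        #         dynamic(<query_embedding>), vector)
+        #   | top <k> by similarity desc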
+        return []
+
+    def similarity_search_by_text(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform a similarity search using a given input text."""
+        return []
diff --git a/graphrag/vector_stores/typing.py b/graphrag/vector_stores/typing.py
index 0b5a5cd195..b2b3386932 100644
--- a/graphrag/vector_stores/typing.py
+++ b/graphrag/vector_stores/typing.py
@@ -8,6 +8,7 @@
 
 from .azure_ai_search import AzureAISearch
+from .kustodb import KustoDBVectorStore
 from .lancedb import LanceDBVectorStore
 
 
 class VectorStoreType(str, Enum):
@@ -15,6 +16,7 @@ class VectorStoreType(str, Enum):
 
     LanceDB = "lancedb"
     AzureAISearch = "azure_ai_search"
+    KustoDB = "kustodb"
 
 
 class VectorStoreFactory:
@@ -30,13 +32,15 @@ def register(cls, vector_store_type: str, vector_store: type):
     @classmethod
     def get_vector_store(
         cls, vector_store_type: VectorStoreType | str, kwargs: dict
-    ) -> LanceDBVectorStore | AzureAISearch:
+    ) -> LanceDBVectorStore | AzureAISearch | KustoDBVectorStore:
         """Get the vector store type from a string."""
         match vector_store_type:
             case VectorStoreType.LanceDB:
                 return LanceDBVectorStore(**kwargs)
             case VectorStoreType.AzureAISearch:
                 return AzureAISearch(**kwargs)
+            case VectorStoreType.KustoDB:
+                return KustoDBVectorStore() # KustoDB: Pass required arguments here
             case _:
                 if vector_store_type in cls.vector_store_types:
                     return cls.vector_store_types[vector_store_type](**kwargs)

From 649edcd26fb5450c4b4363ff3a0b31914d608fb1 Mon Sep 17 00:00:00 2001
From: Amritpal Singh
Date: Tue, 20 Aug 2024 10:59:53 -0700
Subject: [PATCH 02/87] Minor changes + Context Switch flow not working

---
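Note for reviewers: the context switch flow below is checked in but not
working yet. As a minimal sketch, the flags added in PATCH 01 are expected
to be driven like this once it works (the GUID is a placeholder only):

    python -m graphrag.index --root ./ragtest \
        --contextId 00000000-0000-0000-0000-000000000000 \
        --contextOperation activate

contextOperation must be "activate" or "deactivate" (see ContextSwitchType),
and contextId must be a valid GUID (validated by common_utils.is_valid_guid).
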
 README.md                                        |  4 ++++
 graphrag/common/utils/common_utils.py            | 11 +++++++++++
 graphrag/common/utils/context_utils.py           |  9 +++++++++
 graphrag/config/create_graphrag_config.py        |  2 ++
 graphrag/config/models/storage_config.py         |  3 +++
 graphrag/index/__main__.py                       |  4 ++--
 graphrag/index/cli.py                            | 13 +++++++++----
 graphrag/index/context_switch/contextSwitcher.py | 11 ++++++-----
 graphrag/index/init_content.py                   |  3 +++
 graphrag/index/input/text.py                     |  7 ++++++-
 graphrag/index/storage/blob_pipeline_storage.py  |  3 +++
 graphrag/query/cli.py                            |  9 +++------
 12 files changed, 61 insertions(+), 18 deletions(-)
 create mode 100644 graphrag/common/utils/common_utils.py
 create mode 100644 graphrag/common/utils/context_utils.py

diff --git a/README.md b/README.md
index 0b936058ce..07b5b66975 100644
--- a/README.md
+++ b/README.md
@@ -69,3 +69,7 @@ Any use of third-party trademarks or logos are subject to those third-party's po
 ## Privacy
 
 [Microsoft Privacy Statement](https://privacy.microsoft.com/en-us/privacystatement)
+
+
+## Updates
+- Added a new `query_context` setting with a `files` list: the artifact folders to run queries on (see the example below).
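+
+  For example, a sketch of the corresponding `settings.yaml` entry (the
+  artifact paths are illustrative only):
+
+  ```yaml
+  query_context:
+    files: [output/contextA/artifacts, output/contextB/artifacts]
+  ```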
diff --git a/graphrag/common/utils/common_utils.py b/graphrag/common/utils/common_utils.py
new file mode 100644
index 0000000000..d53345d245
--- /dev/null
+++ b/graphrag/common/utils/common_utils.py
@@ -0,0 +1,11 @@
+import uuid
+
+def is_valid_guid(guid_str):
+    """Utility to check for a valid Guid."""
+    try:
+        # Attempt to create a UUID object
+        uuid_obj = uuid.UUID(guid_str, version=4)
+        # Check if the string representation matches the UUID object
+        return str(uuid_obj) == guid_str
+    except ValueError:
+        return False
\ No newline at end of file
diff --git a/graphrag/common/utils/context_utils.py b/graphrag/common/utils/context_utils.py
new file mode 100644
index 0000000000..687779e9c1
--- /dev/null
+++ b/graphrag/common/utils/context_utils.py
@@ -0,0 +1,9 @@
+from graphrag.config import (
+    GraphRagConfig,
+)
+
+def get_files_by_contextid(config: GraphRagConfig, context_id: str):
+    """Utility function to get files by context id."""
+    # General: eventually this will be coming from cosmos db or any other storage
+    filesInContext = config.query_context.files
+    return filesInContext
\ No newline at end of file
diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py
index 5686832736..0903991548 100644
--- a/graphrag/config/create_graphrag_config.py
+++ b/graphrag/config/create_graphrag_config.py
@@ -381,6 +381,7 @@ def hydrate_parallelization_params(
             storage_account_blob_url=reader.str(Fragment.storage_account_blob_url),
             container_name=reader.str(Fragment.container_name),
             base_dir=reader.str(Fragment.base_dir) or defs.STORAGE_BASE_DIR,
+            overwrite=reader.bool(Fragment.overwrite) or False,
         )
     with reader.envvar_prefix(Section.chunk), reader.use(values.get("chunks")):
         group_by_columns = reader.list("group_by_columns", "BY_COLUMNS")
@@ -589,6 +590,7 @@ class Fragment(str, Enum):
     api_proxy = "API_PROXY"
     async_mode = "ASYNC_MODE"
     base_dir = "BASE_DIR"
+    overwrite = "OVERWRITE"
     cognitive_services_endpoint = "COGNITIVE_SERVICES_ENDPOINT"
     concurrent_requests = "CONCURRENT_REQUESTS"
     conn_string = "CONNECTION_STRING"
diff --git a/graphrag/config/models/storage_config.py b/graphrag/config/models/storage_config.py
index dcf41b9222..b3b5c70fe0 100644
--- a/graphrag/config/models/storage_config.py
+++ b/graphrag/config/models/storage_config.py
@@ -28,3 +28,6 @@ class StorageConfig(BaseModel):
     storage_account_blob_url: str | None = Field(
         description="The storage account blob url to use.", default=None
     )
+    overwrite: bool = Field(
+        description="If true, overwrite existing artifacts without raising an error; otherwise raise an error.", default=False
+    )
diff --git a/graphrag/index/__main__.py b/graphrag/index/__main__.py
index 7eb68ffb2a..a498ba994e 100644
--- a/graphrag/index/__main__.py
+++ b/graphrag/index/__main__.py
@@ -71,13 +71,13 @@
     parser.add_argument(
         "--contextId",
         help="Context id to activate or deactivate.",
-        action="store_true",
+        type=str,
     )
     parser.add_argument(
         "--contextOperation",
         help="Context operation activate or deactivate.",
         # Only required if contextId is provided
-        action="store_true",
+        type=str,
     )
 
     args = parser.parse_args()
diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py
index 506656111c..99471c9b6a 100644
--- a/graphrag/index/cli.py
+++ b/graphrag/index/cli.py
@@ -10,13 +10,14 @@
 import sys
 import time
 import warnings
-from enum import Enum
 from pathlib import Path
 
 from graphrag.config import (
+    GraphRagConfig,
     create_graphrag_config,
 )
 from graphrag.config.enums import ContextSwitchType
+from graphrag.common.utils.common_utils import is_valid_guid
 from graphrag.index import PipelineConfig, create_pipeline_config
 from graphrag.index.cache import NoopPipelineCache
 from graphrag.index.progress import (
@@ -39,7 +40,6 @@
 
 log = logging.getLogger(__name__)
 
-
 def redact(input: dict) -> str:
     """Sanitize the config json."""
@@ -102,7 +102,12 @@ def index_cli(
         root, None, verbose, dryrun or False, progress_reporter
     )
     if contextId:
-        _switch_context(pipeline_config, contextOperation, contextId, progress_reporter)
+        if not is_valid_guid(contextId):
+            raise ValueError("ContextId is invalid: it should be a valid Guid")
+        if contextOperation not in (ContextSwitchType.Activate, ContextSwitchType.Deactivate):
+            raise ValueError("ContextOperation is invalid: it should be activate or deactivate")
+        graphrag_config = _read_config_parameters(root, config, progress_reporter)
+        _switch_context(graphrag_config, contextOperation, contextId, progress_reporter)
         sys.exit(0)
     cache = NoopPipelineCache() if nocache else None
     pipeline_emit = emit.split(",") if emit else None
     encountered_errors = False
@@ -182,7 +187,7 @@ async def execute():
     if cli:
         sys.exit(1 if encountered_errors else 0)
 
-def _switch_context(config: PipelineConfig | str, context_operation: str | None, context_id: str, reporter: ProgressReporter) -> None:
+def _switch_context(config: GraphRagConfig | str, context_operation: str | None, context_id: str, reporter: ProgressReporter) -> None:
     """Switch the context to the given context."""
     reporter.info(f"Switching context to {context_id} using operation {context_operation}")
     from graphrag.index.context_switch.contextSwitcher import ContextSwitcher
diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py
index 631e1f97f3..7ff1fa023f 100644
--- a/graphrag/index/context_switch/contextSwitcher.py
+++ b/graphrag/index/context_switch/contextSwitcher.py
@@ -1,5 +1,5 @@
 from graphrag.index.progress import ProgressReporter
-from graphrag.index import PipelineConfig
+from graphrag.config import GraphRagConfig
 
 class ContextSwitcher:
     """ContextSwitcher class definition."""
@@ -7,16 +7,17 @@ class ContextSwitcher:
     def __init__(self):
         #initialize Gremlin and Cosmos Db client here.
         pass
-    def activate(self, config: PipelineConfig | str, contextId: str | None, reporter: ProgressReporter):
+    def activate(self, config: GraphRagConfig | str, contextId: str | None, reporter: ProgressReporter):
         """Activate the context."""
+
         #1. read the context id to fileId mapping.
-        #2. read the file from storage.
-        #3. LanceDB: use cosmos db client to load data into Cosmos DB.
+        #2. read the file from storage using common/blob_storage_client.py
+        #3. GraphDB: use cosmos db client to load data into Cosmos DB.
         #4. KustoDB: use Kusto client to load embedding data into Kusto.
 
         return 0
 
-    def deactivate(self, config: PipelineConfig | str, contextId: str | None, reporter: ProgressReporter):
+    def deactivate(self, config: GraphRagConfig | str, contextId: str | None, reporter: ProgressReporter):
         """Deactivate the context."""
         #1. Delete all the data for a given context id.
         return 0
\ No newline at end of file
diff --git a/graphrag/index/init_content.py b/graphrag/index/init_content.py
index d2ad3906b1..af4e11b606 100644
--- a/graphrag/index/init_content.py
+++ b/graphrag/index/init_content.py
@@ -158,6 +158,9 @@
 # map_max_tokens: {defs.GLOBAL_SEARCH_MAP_MAX_TOKENS}
 # reduce_max_tokens: {defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS}
 # concurrency: {defs.GLOBAL_SEARCH_CONCURRENCY}
+
+query_context:
+  # files: [] # list of files in context to run query
 """
 
 INIT_DOTENV = """
diff --git a/graphrag/index/input/text.py b/graphrag/index/input/text.py
index b2cb97225b..eceb8bb06a 100644
--- a/graphrag/index/input/text.py
+++ b/graphrag/index/input/text.py
@@ -48,10 +48,15 @@ async def load_file(
             base_dir=config.base_dir,
         )
     )
-    #change here to run indexer on each file one by one.
+
     if len(files) == 0:
         msg = f"No text files found in {config.base_dir}"
         raise ValueError(msg)
+
+    if len(files) > 1:
+        msg = f"Found more than one text file in base dir {config.base_dir}"
+        raise ValueError(msg)
+
     found_files = f"found text files from {config.base_dir}, found {files}"
     log.info(found_files)
diff --git a/graphrag/index/storage/blob_pipeline_storage.py b/graphrag/index/storage/blob_pipeline_storage.py
index 7e60df9697..e68a86378b 100644
--- a/graphrag/index/storage/blob_pipeline_storage.py
+++ b/graphrag/index/storage/blob_pipeline_storage.py
@@ -36,6 +36,7 @@ def __init__(
         encoding: str | None = None,
         path_prefix: str | None = None,
         storage_account_blob_url: str | None = None,
+        overwrite: bool = False,
     ):
         """Create a new BlobStorage instance."""
         if connection_string:
@@ -196,6 +197,8 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
             self._container_name
         )
         blob_client = container_client.get_blob_client(key)
+        if blob_client.exists():
+            raise ValueError("Artifacts already exist; make sure the output folder is empty.")
         if isinstance(value, bytes):
             blob_client.upload_blob(value, overwrite=True)
         else:
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index b41c844ddd..7c38ad8733 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -7,14 +7,15 @@
 from pathlib import Path
 from typing import cast
 from io import BytesIO
+from graphrag.common.utils.context_utils import get_files_by_contextid
 from graphrag.config.enums import StorageType
 from azure.core.exceptions import ResourceNotFoundError
 
 import pandas as pd
 
 from graphrag.config import (
-    GraphRagConfig,
     create_graphrag_config,
+    GraphRagConfig,
 )
 from graphrag.index.progress import PrintProgressReporter
 from graphrag.model.entity import Entity
@@ -140,7 +141,7 @@ def run_local_search(
     )
 
     data_paths = []
-    data_paths = get_files_by_context(config, context_id)
+    data_paths = get_files_by_contextid(config, context_id)
     #data_paths = [Path("E:\\graphrag\\ragtest6\\output\\AtoG\\artifacts")]
     #data_paths = [Path("E:\\graphrag\\auditlogstest\\output\\securityPlatformPPE\\artifacts"),Path("E:\\graphrag\\auditlogstest\\output\\UnifiedFeedbackPPE\\artifacts")]
     #data_paths.append(Path(data_dir))
@@ -204,10 +205,6 @@ def run_local_search(
     result = search_engine.search(query=query)
reporter.success(f"Local Search Response: {result.response}") return result.response - -def get_files_by_context(config: GraphRagConfig, context_id: str): - data_paths = config.query_context.files - return data_paths def blob_exists(container_client, blob_name): blob_client = container_client.get_blob_client(blob_name) From 4875079d5e1e5c2bc6eadd08efb0bf5ff8e532ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Tue, 20 Aug 2024 13:27:24 -0700 Subject: [PATCH 03/87] Writing edges and vertices into graphdb --- graphrag/index/create_pipeline_config.py | 2 + graphrag/index/emit/factories.py | 3 + graphrag/index/emit/types.py | 1 + poetry.lock | 459 ++++++++++++++++++++++- pyproject.toml | 1 + 5 files changed, 465 insertions(+), 1 deletion(-) diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index 22dba20029..4472ef8828 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -6,6 +6,7 @@ import json import logging from pathlib import Path +from .emit.types import TableEmitterType from graphrag.config.enums import ( CacheType, @@ -354,6 +355,7 @@ def _graph_workflows( ), "skip_name_embedding": skip_entity_name_embedding, "skip_description_embedding": skip_entity_description_embedding, + "emitter_type": TableEmitterType.Graphdb, }, ), PipelineWorkflowReference( diff --git a/graphrag/index/emit/factories.py b/graphrag/index/emit/factories.py index 84afa68443..cd9e203917 100644 --- a/graphrag/index/emit/factories.py +++ b/graphrag/index/emit/factories.py @@ -9,6 +9,7 @@ from .csv_table_emitter import CSVTableEmitter from .json_table_emitter import JsonTableEmitter from .parquet_table_emitter import ParquetTableEmitter +from .graph_db_emitter import GraphDBEmitter from .table_emitter import TableEmitter from .types import TableEmitterType @@ -24,6 +25,8 @@ def create_table_emitter( return ParquetTableEmitter(storage, on_error) case TableEmitterType.CSV: return CSVTableEmitter(storage) + case TableEmitterType.Graphdb: + return GraphDBEmitter() case _: msg = f"Unsupported table emitter type: {emitter_type}" raise ValueError(msg) diff --git a/graphrag/index/emit/types.py b/graphrag/index/emit/types.py index ab3452856f..0b0ff88541 100644 --- a/graphrag/index/emit/types.py +++ b/graphrag/index/emit/types.py @@ -12,3 +12,4 @@ class TableEmitterType(str, Enum): Json = "json" Parquet = "parquet" CSV = "csv" + Graphdb = "graphdb" diff --git a/poetry.lock b/poetry.lock index 278c8ad98b..49b84c97e0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,17 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
+[[package]] +name = "aenum" +version = "3.1.15" +description = "Advanced Enumerations (compatible with Python's stdlib Enum), NamedTuples, and NamedConstants" +optional = false +python-versions = "*" +files = [ + {file = "aenum-3.1.15-py2-none-any.whl", hash = "sha256:27b1710b9d084de6e2e695dab78fe9f269de924b51ae2850170ee7e1ca6288a5"}, + {file = "aenum-3.1.15-py3-none-any.whl", hash = "sha256:e0dfaeea4c2bd362144b87377e2c61d91958c5ed0b4daf89cb6f45ae23af6288"}, + {file = "aenum-3.1.15.tar.gz", hash = "sha256:8cbd76cd18c4f870ff39b24284d3ea028fbe8731a58df3aa581e434c575b9559"}, +] + [[package]] name = "aiofiles" version = "24.1.0" @@ -11,6 +23,114 @@ files = [ {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, ] +[[package]] +name = "aiohappyeyeballs" +version = "2.3.7" +description = "Happy Eyeballs for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiohappyeyeballs-2.3.7-py3-none-any.whl", hash = "sha256:337ce4dc0e99eb697c3c5a77d6cb3c52925824d9a67ac0dea7c55b8a2d60b222"}, + {file = "aiohappyeyeballs-2.3.7.tar.gz", hash = "sha256:e794cd29ba6a14078092984e43688212a19081de3a73b6796c2fdeb3706dd6ce"}, +] + +[[package]] +name = "aiohttp" +version = "3.10.3" +description = "Async http client/server framework (asyncio)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiohttp-3.10.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc36cbdedf6f259371dbbbcaae5bb0e95b879bc501668ab6306af867577eb5db"}, + {file = "aiohttp-3.10.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:85466b5a695c2a7db13eb2c200af552d13e6a9313d7fa92e4ffe04a2c0ea74c1"}, + {file = "aiohttp-3.10.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:71bb1d97bfe7e6726267cea169fdf5df7658831bb68ec02c9c6b9f3511e108bb"}, + {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baec1eb274f78b2de54471fc4c69ecbea4275965eab4b556ef7a7698dee18bf2"}, + {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13031e7ec1188274bad243255c328cc3019e36a5a907978501256000d57a7201"}, + {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2bbc55a964b8eecb341e492ae91c3bd0848324d313e1e71a27e3d96e6ee7e8e8"}, + {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8cc0564b286b625e673a2615ede60a1704d0cbbf1b24604e28c31ed37dc62aa"}, + {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f817a54059a4cfbc385a7f51696359c642088710e731e8df80d0607193ed2b73"}, + {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8542c9e5bcb2bd3115acdf5adc41cda394e7360916197805e7e32b93d821ef93"}, + {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:671efce3a4a0281060edf9a07a2f7e6230dca3a1cbc61d110eee7753d28405f7"}, + {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:0974f3b5b0132edcec92c3306f858ad4356a63d26b18021d859c9927616ebf27"}, + {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:44bb159b55926b57812dca1b21c34528e800963ffe130d08b049b2d6b994ada7"}, + {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:6ae9ae382d1c9617a91647575255ad55a48bfdde34cc2185dd558ce476bf16e9"}, + {file = "aiohttp-3.10.3-cp310-cp310-win32.whl", hash = 
"sha256:aed12a54d4e1ee647376fa541e1b7621505001f9f939debf51397b9329fd88b9"}, + {file = "aiohttp-3.10.3-cp310-cp310-win_amd64.whl", hash = "sha256:b51aef59370baf7444de1572f7830f59ddbabd04e5292fa4218d02f085f8d299"}, + {file = "aiohttp-3.10.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e021c4c778644e8cdc09487d65564265e6b149896a17d7c0f52e9a088cc44e1b"}, + {file = "aiohttp-3.10.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:24fade6dae446b183e2410a8628b80df9b7a42205c6bfc2eff783cbeedc224a2"}, + {file = "aiohttp-3.10.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bc8e9f15939dacb0e1f2d15f9c41b786051c10472c7a926f5771e99b49a5957f"}, + {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5a9ec959b5381271c8ec9310aae1713b2aec29efa32e232e5ef7dcca0df0279"}, + {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a5d0ea8a6467b15d53b00c4e8ea8811e47c3cc1bdbc62b1aceb3076403d551f"}, + {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c9ed607dbbdd0d4d39b597e5bf6b0d40d844dfb0ac6a123ed79042ef08c1f87e"}, + {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3e66d5b506832e56add66af88c288c1d5ba0c38b535a1a59e436b300b57b23e"}, + {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fda91ad797e4914cca0afa8b6cccd5d2b3569ccc88731be202f6adce39503189"}, + {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:61ccb867b2f2f53df6598eb2a93329b5eee0b00646ee79ea67d68844747a418e"}, + {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:6d881353264e6156f215b3cb778c9ac3184f5465c2ece5e6fce82e68946868ef"}, + {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:b031ce229114825f49cec4434fa844ccb5225e266c3e146cb4bdd025a6da52f1"}, + {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5337cc742a03f9e3213b097abff8781f79de7190bbfaa987bd2b7ceb5bb0bdec"}, + {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ab3361159fd3dcd0e48bbe804006d5cfb074b382666e6c064112056eb234f1a9"}, + {file = "aiohttp-3.10.3-cp311-cp311-win32.whl", hash = "sha256:05d66203a530209cbe40f102ebaac0b2214aba2a33c075d0bf825987c36f1f0b"}, + {file = "aiohttp-3.10.3-cp311-cp311-win_amd64.whl", hash = "sha256:70b4a4984a70a2322b70e088d654528129783ac1ebbf7dd76627b3bd22db2f17"}, + {file = "aiohttp-3.10.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:166de65e2e4e63357cfa8417cf952a519ac42f1654cb2d43ed76899e2319b1ee"}, + {file = "aiohttp-3.10.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7084876352ba3833d5d214e02b32d794e3fd9cf21fdba99cff5acabeb90d9806"}, + {file = "aiohttp-3.10.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d98c604c93403288591d7d6d7d6cc8a63459168f8846aeffd5b3a7f3b3e5e09"}, + {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d73b073a25a0bb8bf014345374fe2d0f63681ab5da4c22f9d2025ca3e3ea54fc"}, + {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8da6b48c20ce78f5721068f383e0e113dde034e868f1b2f5ee7cb1e95f91db57"}, + {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3a9dcdccf50284b1b0dc72bc57e5bbd3cc9bf019060dfa0668f63241ccc16aa7"}, + {file = 
"aiohttp-3.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56fb94bae2be58f68d000d046172d8b8e6b1b571eb02ceee5535e9633dcd559c"}, + {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bf75716377aad2c718cdf66451c5cf02042085d84522aec1f9246d3e4b8641a6"}, + {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6c51ed03e19c885c8e91f574e4bbe7381793f56f93229731597e4a499ffef2a5"}, + {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b84857b66fa6510a163bb083c1199d1ee091a40163cfcbbd0642495fed096204"}, + {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c124b9206b1befe0491f48185fd30a0dd51b0f4e0e7e43ac1236066215aff272"}, + {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3461d9294941937f07bbbaa6227ba799bc71cc3b22c40222568dc1cca5118f68"}, + {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:08bd0754d257b2db27d6bab208c74601df6f21bfe4cb2ec7b258ba691aac64b3"}, + {file = "aiohttp-3.10.3-cp312-cp312-win32.whl", hash = "sha256:7f9159ae530297f61a00116771e57516f89a3de6ba33f314402e41560872b50a"}, + {file = "aiohttp-3.10.3-cp312-cp312-win_amd64.whl", hash = "sha256:e1128c5d3a466279cb23c4aa32a0f6cb0e7d2961e74e9e421f90e74f75ec1edf"}, + {file = "aiohttp-3.10.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:d1100e68e70eb72eadba2b932b185ebf0f28fd2f0dbfe576cfa9d9894ef49752"}, + {file = "aiohttp-3.10.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a541414578ff47c0a9b0b8b77381ea86b0c8531ab37fc587572cb662ccd80b88"}, + {file = "aiohttp-3.10.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d5548444ef60bf4c7b19ace21f032fa42d822e516a6940d36579f7bfa8513f9c"}, + {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ba2e838b5e6a8755ac8297275c9460e729dc1522b6454aee1766c6de6d56e5e"}, + {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:48665433bb59144aaf502c324694bec25867eb6630fcd831f7a893ca473fcde4"}, + {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bac352fceed158620ce2d701ad39d4c1c76d114255a7c530e057e2b9f55bdf9f"}, + {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b0f670502100cdc567188c49415bebba947eb3edaa2028e1a50dd81bd13363f"}, + {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43b09f38a67679e32d380fe512189ccb0b25e15afc79b23fbd5b5e48e4fc8fd9"}, + {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:cd788602e239ace64f257d1c9d39898ca65525583f0fbf0988bcba19418fe93f"}, + {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:214277dcb07ab3875f17ee1c777d446dcce75bea85846849cc9d139ab8f5081f"}, + {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:32007fdcaab789689c2ecaaf4b71f8e37bf012a15cd02c0a9db8c4d0e7989fa8"}, + {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:123e5819bfe1b87204575515cf448ab3bf1489cdeb3b61012bde716cda5853e7"}, + {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:812121a201f0c02491a5db335a737b4113151926a79ae9ed1a9f41ea225c0e3f"}, + {file = "aiohttp-3.10.3-cp38-cp38-win32.whl", hash = "sha256:b97dc9a17a59f350c0caa453a3cb35671a2ffa3a29a6ef3568b523b9113d84e5"}, + {file = 
"aiohttp-3.10.3-cp38-cp38-win_amd64.whl", hash = "sha256:3731a73ddc26969d65f90471c635abd4e1546a25299b687e654ea6d2fc052394"}, + {file = "aiohttp-3.10.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:38d91b98b4320ffe66efa56cb0f614a05af53b675ce1b8607cdb2ac826a8d58e"}, + {file = "aiohttp-3.10.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9743fa34a10a36ddd448bba8a3adc2a66a1c575c3c2940301bacd6cc896c6bf1"}, + {file = "aiohttp-3.10.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7c126f532caf238031c19d169cfae3c6a59129452c990a6e84d6e7b198a001dc"}, + {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:926e68438f05703e500b06fe7148ef3013dd6f276de65c68558fa9974eeb59ad"}, + {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:434b3ab75833accd0b931d11874e206e816f6e6626fd69f643d6a8269cd9166a"}, + {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d35235a44ec38109b811c3600d15d8383297a8fab8e3dec6147477ec8636712a"}, + {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59c489661edbd863edb30a8bd69ecb044bd381d1818022bc698ba1b6f80e5dd1"}, + {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50544fe498c81cb98912afabfc4e4d9d85e89f86238348e3712f7ca6a2f01dab"}, + {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:09bc79275737d4dc066e0ae2951866bb36d9c6b460cb7564f111cc0427f14844"}, + {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:af4dbec58e37f5afff4f91cdf235e8e4b0bd0127a2a4fd1040e2cad3369d2f06"}, + {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:b22cae3c9dd55a6b4c48c63081d31c00fc11fa9db1a20c8a50ee38c1a29539d2"}, + {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ba562736d3fbfe9241dad46c1a8994478d4a0e50796d80e29d50cabe8fbfcc3f"}, + {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:f25d6c4e82d7489be84f2b1c8212fafc021b3731abdb61a563c90e37cced3a21"}, + {file = "aiohttp-3.10.3-cp39-cp39-win32.whl", hash = "sha256:b69d832e5f5fa15b1b6b2c8eb6a9fd2c0ec1fd7729cb4322ed27771afc9fc2ac"}, + {file = "aiohttp-3.10.3-cp39-cp39-win_amd64.whl", hash = "sha256:673bb6e3249dc8825df1105f6ef74e2eab779b7ff78e96c15cadb78b04a83752"}, + {file = "aiohttp-3.10.3.tar.gz", hash = "sha256:21650e7032cc2d31fc23d353d7123e771354f2a3d5b05a5647fc30fea214e696"}, +] + +[package.dependencies] +aiohappyeyeballs = ">=2.3.0" +aiosignal = ">=1.1.2" +async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} +attrs = ">=17.3.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +yarl = ">=1.0,<2.0" + +[package.extras] +speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] + [[package]] name = "aiolimiter" version = "1.1.0" @@ -22,6 +142,20 @@ files = [ {file = "aiolimiter-1.1.0.tar.gz", hash = "sha256:461cf02f82a29347340d031626c92853645c099cb5ff85577b831a7bd21132b5"}, ] +[[package]] +name = "aiosignal" +version = "1.3.1" +description = "aiosignal: a list of registered asynchronous callbacks" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, + {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, +] + +[package.dependencies] +frozenlist = 
">=1.1.0" + [[package]] name = "annotated-types" version = "0.7.0" @@ -188,6 +322,17 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} +[[package]] +name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, +] + [[package]] name = "attrs" version = "24.1.0" @@ -1204,6 +1349,7 @@ files = [ {file = "fastparquet-2024.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5626fc72204001b7e82fedb4b02174ecb4e2d4143b38b4ea8d2f9eb65f6b000e"}, {file = "fastparquet-2024.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c8b2e86fe6488cce0e3d41263bb0296ef9bbb875a2fca09d67d7685640017a66"}, {file = "fastparquet-2024.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2a951106782d51e5ab110beaad29c4aa0537f045711bb0bf146f65aeaed14174"}, + {file = "fastparquet-2024.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:cd3473d3e299bfb04c0ac7726cca5d13ee450cc2387ee7fd70587ca150647315"}, {file = "fastparquet-2024.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:47695037fdc534ef4247f25ccf17dcbd8825be6ecb70c54ca54d588a794f4a6d"}, {file = "fastparquet-2024.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc3d35ff8341cd65baecac71062e9d73393d7afda207b3421709c1d3f4baa194"}, {file = "fastparquet-2024.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:691348cc85890663dd3c0bb02544d38d4c07a0c3d68837324dc01007301150b5"}, @@ -1301,6 +1447,92 @@ files = [ {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] +[[package]] +name = "frozenlist" +version = "1.4.1" +description = "A list-like structure which implements collections.abc.MutableSequence" +optional = false +python-versions = ">=3.8" +files = [ + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, + {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, + {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, + {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, + {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, + {file = 
"frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, + {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, + {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, + {file 
= "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, + {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, + {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, + {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, + {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = 
"sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, + {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, + {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, +] + [[package]] name = "fsspec" version = "2024.6.1" @@ -1439,6 +1671,28 @@ files = [ {file = "graspologic_native-1.2.1.tar.gz", hash = "sha256:72b7586028a91e9fef9af0ef314d368f0240c18dca99e6e6c546334359a8610a"}, ] +[[package]] +name = "gremlinpython" +version = "3.7.2" +description = "Gremlin-Python for Apache TinkerPop" +optional = false +python-versions = "*" +files = [ + {file = "gremlinpython-3.7.2-py2.py3-none-any.whl", hash = "sha256:ff925632f70fc5afff6864627fa69de8df577090a6eea63f10f0fed7194f98bf"}, + {file = "gremlinpython-3.7.2.tar.gz", hash = "sha256:58d12e4af81210d7770d54da8f1f586de816bc00858e095f76326d338ea34acd"}, +] + +[package.dependencies] +aenum = ">=1.4.5,<4.0.0" +aiohttp = ">=3.8.0,<4.0.0" +isodate = ">=0.6.0,<1.0.0" +nest-asyncio = "*" +six = ">=1.10.0,<2.0.0" + +[package.extras] +kerberos = ["kerberos (>=1.3.0,<2.0.0)"] +ujson = ["ujson (>=2.0.0)"] + [[package]] name = "h11" version = "0.14.0" @@ -2522,6 +2776,105 @@ files = [ msal = ">=1.29,<2" portalocker = ">=1.4,<3" +[[package]] +name = "multidict" +version = "6.0.5" +description = "multidict implementation" +optional = false +python-versions = ">=3.7" +files = [ + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = 
"sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, + {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, + {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, + {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, + {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, + {file = 
"multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, + {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, + {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, + {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, + {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, + {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = 
"sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, + {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, + {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, + {file = 
"multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, + {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, + {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, + {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, + {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, +] + [[package]] name = "nbclient" version = "0.10.0" @@ -3786,6 +4139,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -5143,6 +5497,109 @@ files = [ {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"}, ] +[[package]] +name = "yarl" +version = "1.9.4" +description = "Yet another URL library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"}, + {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"}, + {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"}, + {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"}, + {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"}, + {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"}, + {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"}, + {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"}, + {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"}, + {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"}, + {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"}, + {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"}, + {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"}, + {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"}, + {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"}, + {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" + [[package]] name = "zipp" version = "3.19.2" @@ -5161,4 +5618,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "3b42a857248dd22d200eb65786cca8e7ce0378fb603083c62a6d11a8f47f8b45" +content-hash = "b2c972b8acfe25a9da7a68b3c2694a431908d66738b8082d48bfe124a78585b9" diff --git a/pyproject.toml b/pyproject.toml index 7734066dc5..005a4917b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ typing-extensions = "^4.12.2" azure-storage-blob = "^12.19.0" azure-identity = "^1.17.1" json-repair = "^0.25.3" +gremlinpython = "^3.7.2" [tool.poetry.group.dev.dependencies] coverage = "^7.6.0" From a0f9a541864283c924a8b370e83e9ed592f891b0 Mon Sep 17 00:00:00 2001 From: Amritpal Singh Date: Tue, 20 Aug 2024 16:41:11 -0700 Subject: [PATCH 04/87] Fixed context switching issue --- .vscode/launch.json | 16 +++++++++++++--- graphrag/index/__main__.py | 28 +++++++++++++++------------- graphrag/index/cli.py | 12 ++++++------ 3 files changed, 34 insertions(+), 22 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 72f031baf1..07d7ab4afc 100644 --- a/.vscode/launch.json 
+++ b/.vscode/launch.json
@@ -8,17 +8,27 @@
             "request": "launch",
             "cwd": "${workspaceFolder}",
             "module": "poetry",
-            "args": ["poe", "index", "--root", ".\\ragtest6"],
+            "args": ["poe", "index", "","", "--root", ".\\ragtest6"],
             "stopOnEntry": false
         },
         {
-            "name": "QUery",
+            "name": "Query",
             "type": "debugpy",
             "python": "E:\\graphrag\\.venv\\Scripts\\python.exe",
             "request": "launch",
             "cwd": "${workspaceFolder}",
             "module": "poetry",
-            "args": ["poe", "query", "--root", ".\\ragtest6", "--method", "local", "Who provided access to Amritpal at first place to Unified Feedback KV Certificate?"],
+            "args": ["poe", "query", "--root", ".\\ragtest6", "--method", "local", "How did Guillermo get certificate access?"],
+            "stopOnEntry": false
+        },
+        {
+            "name": "Context Switch",
+            "type": "debugpy",
+            "python": "E:\\graphrag\\.venv\\Scripts\\python.exe",
+            "request": "launch",
+            "cwd": "${workspaceFolder}",
+            "module": "poetry",
+            "args": ["poe", "index", "--context-id", "f367f396-b0bf-4547-8d9c-331fc7a39433", "--context-operation", "activate", "--root", ".\\ragtest6"],
             "stopOnEntry": false
         }
     ]
diff --git a/graphrag/index/__main__.py b/graphrag/index/__main__.py
index a498ba994e..7bdac87734 100644
--- a/graphrag/index/__main__.py
+++ b/graphrag/index/__main__.py
@@ -52,6 +52,19 @@
         help="The data formats to emit, comma-separated. Valid values are 'parquet' and 'csv'. default='parquet,csv'",
         type=str,
     )
+    parser.add_argument(
+        "--context-id",
+        required=False,
+        help="Context id to activate or deactivate.",
+        type=str
+    )
+    parser.add_argument(
+        "--context-operation",
+        help="Context operation: activate or deactivate.",
+        required=False,
+        # Only required if --context-id is provided
+        type=str
+    )
     parser.add_argument(
         "--dryrun",
         help="Run the pipeline without actually executing any steps and inspect the configuration.",
         action="store_true",
@@ -68,17 +81,6 @@
         help="Overlay default configuration values on a provided configuration file (--config).",
         action="store_true",
     )
-    parser.add_argument(
-        "--contextId",
-        help="Context id to activate or deactivate.",
-        type=str
-    )
-    parser.add_argument(
-        "--contextOperation",
-        help="Context operation activate or deactivate.",
-        # Only required if contextId is provided
-        type=str
-    )

     args = parser.parse_args()

@@ -98,6 +100,6 @@
         init=args.init or False,
         overlay_defaults=args.overlay_defaults or False,
         cli=True,
-        contextId=args.contextId,
-        contextOperation=args.contextOperation,
+        context_id=args.context_id,
+        context_operation=args.context_operation,
     )
diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py
index 99471c9b6a..cc557a890a 100644
--- a/graphrag/index/cli.py
+++ b/graphrag/index/cli.py
@@ -73,8 +73,8 @@ def redact_dict(input: dict) -> dict:
 def index_cli(
     root: str,
     init: bool,
-    contextOperation: str | None,
-    contextId: str | None,
+    context_operation: str | None,
+    context_id: str | None,
     verbose: bool,
     resume: str | None,
     memprofile: bool,
@@ -101,13 +101,13 @@ def index_cli(
     pipeline_config: str | PipelineConfig = config or _create_default_config(
         root, None, verbose, dryrun or False, progress_reporter
     )
-    if contextId:
-        if not is_valid_guid(contextId):
-            ValueError("ContextId is invalid: It should be a valid Guid")
-        if (contextOperation != ContextSwitchType.Activate or contextOperation != ContextSwitchType.Deactivate):
-            ValueError("ContextOperation is invalid: It should be Active or DeActive")
+    if context_id:
+        if not is_valid_guid(context_id):
+            raise ValueError("ContextId is invalid: it should be a valid GUID.")
+        if context_operation != ContextSwitchType.Activate and context_operation != ContextSwitchType.Deactivate:
+            raise ValueError("ContextOperation is invalid: it should be Activate or Deactivate.")
         graphrag_config = _read_config_parameters(root, config, progress_reporter)
-        _switch_context(graphrag_config, contextOperation, contextId, progress_reporter)
+        _switch_context(graphrag_config, context_operation, context_id, progress_reporter)
         sys.exit(0)
     cache = NoopPipelineCache() if nocache else None
     pipeline_emit = emit.split(",") if emit else None

From 5f80f666a3c0c4c4457d41a43cb3ed65e09f048c Mon Sep 17 00:00:00 2001
From: Amritpal Singh
Date: Tue, 20 Aug 2024 16:55:01 -0700
Subject: [PATCH 05/87] addressed comment

---
 graphrag/index/storage/blob_pipeline_storage.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/graphrag/index/storage/blob_pipeline_storage.py b/graphrag/index/storage/blob_pipeline_storage.py
index e68a86378b..8501734f6d 100644
--- a/graphrag/index/storage/blob_pipeline_storage.py
+++ b/graphrag/index/storage/blob_pipeline_storage.py
@@ -55,6 +55,7 @@ def __init__(
         self._encoding = encoding or "utf-8"
         self._container_name = container_name
         self._connection_string = connection_string
+        self._overwrite = overwrite
         self._path_prefix = path_prefix or ""
         self._storage_account_blob_url = storage_account_blob_url
         self._storage_account_name = (
@@ -197,7 +198,7 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
             self._container_name
         )
         blob_client = container_client.get_blob_client(key)
-        if blob_client.exists():
-            ValueError("Artifacts already exists, make sure output folder is empty.")
+        if blob_client.exists() and not self._overwrite:
+            raise ValueError("Artifact already exists; make sure the output folder is empty.")
         if isinstance(value, bytes):
             blob_client.upload_blob(value, overwrite=True)

From d38a0261fd932d33d1bcacd0a7ef76f718bdb8a9 Mon Sep 17 00:00:00 2001
From: Liam Dannaher
Date: Wed, 7 Aug 2024 08:20:52 -0700
Subject: [PATCH 06/87] Adding azure-kusto-data as a dependency

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 7734066dc5..90af801003 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,7 @@ datashaper = "^0.0.49"
 # Vector Stores
 azure-search-documents = "^11.4.0"
 lancedb = "^0.11.0"
+azure-kusto-data = "^4.5.1"
 # Event Loops
 uvloop = { version = "^0.19.0", markers = "platform_system != 'Windows'" }

From a4efadb6b308a6d3c9efea7e230385430ab57684 Mon Sep 17 00:00:00 2001
From: Liam Dannaher
Date: Wed, 7 Aug 2024 09:03:30 -0700
Subject: [PATCH 07/87] Initial add of Kusto related file changes.

---
 graphrag/query/cli.py            |   1 +
 graphrag/vector_stores/kusto.py  | 116 +++++++++++++++++++++++++++
 graphrag/vector_stores/typing.py |  10 +--
 3 files changed, 122 insertions(+), 5 deletions(-)
 create mode 100644 graphrag/vector_stores/kusto.py

diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index 7c38ad8733..fc7de08d45 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -24,6 +24,7 @@
 )
 from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
 from graphrag.vector_stores.lancedb import LanceDBVectorStore
+from graphrag.vector_stores.kusto import KustoVectorStore
 from graphrag.common.blob_storage_client import BlobStorageClient

 from .factories import get_global_search_engine, get_local_search_engine

diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py
new file mode 100644
index 0000000000..178d4b5266
--- /dev/null
+++ b/graphrag/vector_stores/kusto.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""The Azure Kusto vector storage implementation package.""" + +from azure.kusto.data import KustoClient, KustoConnectionStringBuilder +from azure.kusto.data.helpers import dataframe_from_result_table +from graphrag.model.types import TextEmbedder + +import json +from typing import Any, List + +from .base import ( + BaseVectorStore, + VectorStoreDocument, + VectorStoreSearchResult, +) + + +class KustoVectorStore(BaseVectorStore): + """The Azure Kusto vector storage implementation.""" + + def connect(self, **kwargs: Any) -> Any: + """Connect to the vector storage.""" + cluster = kwargs.get("cluster") + database = kwargs.get("database") + client_id = kwargs.get("client_id") + client_secret = kwargs.get("client_secret") + authority_id = kwargs.get("authority_id") + + kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication( + cluster, client_id, client_secret, authority_id + ) + self.client = KustoClient(kcsb) + self.database = database + + def load_documents( + self, documents: List[VectorStoreDocument], overwrite: bool = True + ) -> None: + """Load documents into vector storage.""" + data = [ + { + "id": document.id, + "text": document.text, + "vector": document.vector, + "attributes": json.dumps(document.attributes), + } + for document in documents + if document.vector is not None + ] + + if len(data) == 0: + return + + # Convert data to DataFrame + import pandas as pd + df = pd.DataFrame(data) + + # Create or replace table + if overwrite: + command = f".drop table {self.collection_name} ifexists; .create table {self.collection_name} (id: string, text: string, vector: dynamic, attributes: string)" + self.client.execute(self.database, command) + + # Ingest data + ingestion_command = f".ingest inline into table {self.collection_name} <| {df.to_csv(index=False, header=False)}" + self.client.execute(self.database, ingestion_command) + + def filter_by_id(self, include_ids: List[str] | List[int]) -> Any: + """Build a query filter to filter documents by id.""" + if len(include_ids) == 0: + self.query_filter = None + else: + if isinstance(include_ids[0], str): + id_filter = ", ".join([f"'{id}'" for id in include_ids]) + self.query_filter = f"id in ({id_filter})" + else: + self.query_filter = ( + f"id in ({', '.join([str(id) for id in include_ids])})" + ) + return self.query_filter + + def similarity_search_by_vector( + self, query_embedding: List[float], k: int = 10, **kwargs: Any + ) -> List[VectorStoreSearchResult]: + """Perform a vector-based similarity search.""" + query = f""" + let query_vector = dynamic({query_embedding}); + {self.collection_name} + | extend distance = array_length(set_difference(vector, query_vector)) + | where distance <= {k} + | top {k} by distance asc + """ + response = self.client.execute(self.database, query) + df = dataframe_from_result_table(response.primary_results[0]) + + return [ + VectorStoreSearchResult( + document=VectorStoreDocument( + id=row["id"], + text=row["text"], + vector=row["vector"], + attributes=json.loads(row["attributes"]), + ), + score=1 - abs(float(row["distance"])), + ) + for _, row in df.iterrows() + ] + + def similarity_search_by_text( + self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + ) -> List[VectorStoreSearchResult]: + """Perform a similarity search using a given input text.""" + query_embedding = text_embedder(text) + if query_embedding: + return self.similarity_search_by_vector(query_embedding, k) + return [] diff --git a/graphrag/vector_stores/typing.py 
b/graphrag/vector_stores/typing.py index b2b3386932..459d5b5f56 100644 --- a/graphrag/vector_stores/typing.py +++ b/graphrag/vector_stores/typing.py @@ -8,7 +8,7 @@ from .azure_ai_search import AzureAISearch from .lancedb import LanceDBVectorStore -from .kustodb import KustoDBVectorStore +from .kusto import KustoVectorStore class VectorStoreType(str, Enum): @@ -16,7 +16,7 @@ class VectorStoreType(str, Enum): LanceDB = "lancedb" AzureAISearch = "azure_ai_search" - KustoDB = "kustodb" + Kusto = "kusto" class VectorStoreFactory: @@ -32,15 +32,15 @@ def register(cls, vector_store_type: str, vector_store: type): @classmethod def get_vector_store( cls, vector_store_type: VectorStoreType | str, kwargs: dict - ) -> LanceDBVectorStore | AzureAISearch | KustoDBVectorStore: + ) -> LanceDBVectorStore | AzureAISearch | KustoVectorStore: """Get the vector store type from a string.""" match vector_store_type: case VectorStoreType.LanceDB: return LanceDBVectorStore(**kwargs) case VectorStoreType.AzureAISearch: return AzureAISearch(**kwargs) - case VectorStoreType.KustoDB: - return KustoDBVectorStore() # KustoDB: Pass required arguments here + case VectorStoreType.Kusto: + return KustoVectorStore(**kwargs) case _: if vector_store_type in cls.vector_store_types: return cls.vector_store_types[vector_store_type](**kwargs) From da4ca97538a7d1cd495e2c767d025a878bcda58d Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Wed, 7 Aug 2024 11:28:57 -0700 Subject: [PATCH 08/87] Drop & Remove db query split due to syntax error --- graphrag/vector_stores/kusto.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 178d4b5266..3d6504e353 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -58,7 +58,9 @@ def load_documents( # Create or replace table if overwrite: - command = f".drop table {self.collection_name} ifexists; .create table {self.collection_name} (id: string, text: string, vector: dynamic, attributes: string)" + command = f".drop table {self.collection_name} ifexists" + self.client.execute(self.database, command) + command = f".create table {self.collection_name} (id: string, text: string, vector: dynamic, attributes: string)" self.client.execute(self.database, command) # Ingest data From c4087992fb630870581877d49d87f9c3e6205e9b Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Wed, 7 Aug 2024 11:31:25 -0700 Subject: [PATCH 09/87] Adding kusto documentation --- graphrag/vector_stores/kusto.py | 58 ++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 3d6504e353..d21d1cbea1 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -21,7 +21,20 @@ class KustoVectorStore(BaseVectorStore): """The Azure Kusto vector storage implementation.""" def connect(self, **kwargs: Any) -> Any: - """Connect to the vector storage.""" + """ + Connect to the vector storage. + + Args: + **kwargs: Arbitrary keyword arguments containing connection parameters. + - cluster (str): The Kusto cluster URL. + - database (str): The Kusto database name. + - client_id (str): The client ID for AAD authentication. + - client_secret (str): The client secret for AAD authentication. + - authority_id (str): The authority ID (tenant ID) for AAD authentication. + + Returns: + Any: The Kusto client instance. 
+ """ cluster = kwargs.get("cluster") database = kwargs.get("database") client_id = kwargs.get("client_id") @@ -37,7 +50,13 @@ def connect(self, **kwargs: Any) -> Any: def load_documents( self, documents: List[VectorStoreDocument], overwrite: bool = True ) -> None: - """Load documents into vector storage.""" + """ + Load documents into vector storage. + + Args: + documents (List[VectorStoreDocument]): List of documents to be loaded. + overwrite (bool): Whether to overwrite the existing table. Defaults to True. + """ data = [ { "id": document.id, @@ -68,7 +87,15 @@ def load_documents( self.client.execute(self.database, ingestion_command) def filter_by_id(self, include_ids: List[str] | List[int]) -> Any: - """Build a query filter to filter documents by id.""" + """ + Build a query filter to filter documents by id. + + Args: + include_ids (List[str] | List[int]): List of document IDs to include in the filter. + + Returns: + Any: The query filter string. + """ if len(include_ids) == 0: self.query_filter = None else: @@ -84,7 +111,17 @@ def filter_by_id(self, include_ids: List[str] | List[int]) -> Any: def similarity_search_by_vector( self, query_embedding: List[float], k: int = 10, **kwargs: Any ) -> List[VectorStoreSearchResult]: - """Perform a vector-based similarity search.""" + """ + Perform a vector-based similarity search. + + Args: + query_embedding (List[float]): The query embedding vector. + k (int): The number of top results to return. Defaults to 10. + **kwargs: Additional keyword arguments. + + Returns: + List[VectorStoreSearchResult]: List of search results. + """ query = f""" let query_vector = dynamic({query_embedding}); {self.collection_name} @@ -111,7 +148,18 @@ def similarity_search_by_vector( def similarity_search_by_text( self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any ) -> List[VectorStoreSearchResult]: - """Perform a similarity search using a given input text.""" + """ + Perform a similarity search using a given input text. + + Args: + text (str): The input text to search for. + text_embedder (TextEmbedder): The text embedder to convert text to vector. + k (int): The number of top results to return. Defaults to 10. + **kwargs: Additional keyword arguments. + + Returns: + List[VectorStoreSearchResult]: List of search results. + """ query_embedding = text_embedder(text) if query_embedding: return self.similarity_search_by_vector(query_embedding, k) From 303e9c479c551052d8a99149558307083ac47bc7 Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Fri, 9 Aug 2024 13:42:59 -0700 Subject: [PATCH 10/87] I cleaned up the code I was working on and added TODOs to blocks of functionality I knew needed work. I am basing this off of my previous KUSTO work. 
---
 graphrag/query/__main__.py                | 22 +++++-
 graphrag/query/cli.py                     | 89 +++++++++++++++++++++++
 graphrag/vector_stores/azure_ai_search.py |  4 +
 graphrag/vector_stores/base.py            |  4 +
 graphrag/vector_stores/kusto.py           | 41 ++++++++++-
 graphrag/vector_stores/lancedb.py         |  3 +
 6 files changed, 160 insertions(+), 3 deletions(-)

diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py
index d572194c80..c6b34a78f4 100644
--- a/graphrag/query/__main__.py
+++ b/graphrag/query/__main__.py
@@ -6,7 +6,7 @@
 import argparse
 from enum import Enum
 
-from .cli import run_global_search, run_local_search
+from .cli import run_global_search, run_local_search, run_content_store_local_search, run_content_store_global_search
 
 INVALID_METHOD_ERROR = "Invalid method"
 
@@ -16,6 +16,8 @@ class SearchType(Enum):
 
     LOCAL = "local"
     GLOBAL = "global"
+    CONTENT_STORE_LOCAL = "content_store_local"
+    CONTENT_STORE_GLOBAL = "content_store_global"
 
     def __str__(self):
         """Return the string representation of the enum value."""
@@ -108,5 +110,23 @@ def __str__(self):
                 args.context_id,
                 args.query[0],
             )
+        case SearchType.CONTENT_STORE_LOCAL:
+            run_content_store_local_search(
+                args.config,
+                args.data,
+                args.root,
+                args.community_level,
+                args.response_type,
+                args.query[0],
+            )
+        case SearchType.CONTENT_STORE_GLOBAL:
+            run_content_store_global_search(
+                args.config,
+                args.data,
+                args.root,
+                args.community_level,
+                args.response_type,
+                args.query[0],
+            )
         case _:
             raise ValueError(INVALID_METHOD_ERROR)
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index fc7de08d45..79b03a7ee8 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -23,6 +23,7 @@
     store_entity_semantic_embeddings,
 )
 from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
+from graphrag.vector_stores.base import BaseVectorStore
 from graphrag.vector_stores.lancedb import LanceDBVectorStore
 from graphrag.vector_stores.kusto import KustoVectorStore
 from graphrag.common.blob_storage_client import BlobStorageClient
@@ -235,6 +236,94 @@ def read_paraquet_file(config:GraphRagConfig, path: str, storageType: StorageTyp
     if not file_path.exists():
         raise ValueError(f"Data path {file_path} does not exist.")
     return pd.read_parquet(path)
+# TODO I split this out for now to preserve how the original local search worked.
+# I don't think this will necessarily be permanently separate.
+# It was just easier to start from a copy than to keep everything generic and matching how local search works today.
+# One last optimization: once all the merges are done, we can go back to the parquet loads, read only the fields we need, and merge them right away into one big table (I think); rough sketch below.
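+# Sketch of that optimization (untested; assumes pandas' `columns` projection and the
+# nodes.title == entities.name join key that the Kusto queries use elsewhere):
+#   nodes = pd.read_parquet(data_path / "create_final_nodes.parquet", columns=["title", "degree", "community", "level"])
+#   entities = pd.read_parquet(data_path / "create_final_entities.parquet")
+#   big_table = nodes.merge(entities, left_on="title", right_on="name", how="inner")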
+def run_content_store_local_search(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    query: str,
+):
+    """Run a content store local search with the given query."""
+    data_dir, root_dir, config = _configure_paths_and_settings(
+        data_dir, root_dir, config_dir
+    )
+    data_path = Path(data_dir)
+
+    vector_store_args = (
+        config.embeddings.vector_store if config.embeddings.vector_store else {}
+    )
+
+    vector_store_type = vector_store_args.get("type", VectorStoreType.Kusto)
+
+    collection_name = vector_store_args.get(
+        "query_collection_name", "entity_description_embeddings"
+    )
+    vector_store_args.update({"collection_name": collection_name})
+
+    description_embedding_store = VectorStoreFactory.get_vector_store(
+        vector_store_type=vector_store_type, kwargs=vector_store_args
+    )
+
+    description_embedding_store.connect(**vector_store_args)
+
+    #TODO add back covariates. I skipped this for now.
+    description_embedding_store.load_parqs(data_dir, ["create_final_nodes", "create_final_community_reports", "create_final_text_units", "create_final_relationships", "create_final_entities"])
+
+    #TODO KQLify this. This merge of nodes & entities needs to happen in Kusto.
+    # entities = read_indexer_entities(final_nodes, final_entities, community_level)
+    # description_embedding_store = __get_embedding_description_store(
+    #     entities=entities,
+    #     description_embedding_store=description_embedding_store,
+    #     config_args=vector_store_args,
+    # )
+
+    #TODO add back covariates w/Kusto. I skipped this for now.
+    # covariates = (
+    #     read_indexer_covariates(final_covariates)
+    #     if final_covariates is not None
+    #     else []
+    #     )
+
+    #TODO KQLify this. I know at least the read_indexer_reports needs to be done in Kusto. We are joining the community reports & final nodes.
+    # search_engine = get_local_search_engine(
+    #     config,
+    #     reports=read_indexer_reports(
+    #         final_community_reports, final_nodes, community_level
+    #     ),
+    #     text_units=read_indexer_text_units(final_text_units),
+    #     entities=entities,
+    #     relationships=read_indexer_relationships(final_relationships),
+    #     covariates={"claims": covariates},
+    #     description_embedding_store=description_embedding_store,
+    #     response_type=response_type,
+    # )
+
+    #TODO This is the biggest TODO. I need to go through the whole mixed_context.py and make sure it's using Kusto data, not the parquet data it expects in memory.
+    # result = search_engine.search(query=query)
+    # reporter.success(f"Local Search Response: {result.response}")
+    # return result.response
+
+    return True #Obviously this is a placeholder due to all the TODOs above.
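+
+# Example invocation of the new search type (assuming the usual poe task wiring; the
+# root path and query below are placeholders):
+#   poetry run poe query --root ./ragtest --method content_store_local "What are the top themes?"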
+ + + +def run_content_store_global_search( + config_dir: str | None, + data_dir: str | None, + root_dir: str | None, + community_level: int, + response_type: str, + query: str, +): + """Run a content store global search with the given query.""" + raise NotImplementedError("This function is not implemented yet.") + + def _configure_paths_and_settings( diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index c503fca1c1..f0556b6d2b 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -192,3 +192,7 @@ def similarity_search_by_text( query_embedding=query_embedding, k=k ) return [] + + def load_parqs(self, data_path, parq_names) -> Any: + """Load documents (Parquet files) into the vector-store.""" + raise NotImplementedError("Loading Parquet files is not supported for Azure AI Search") \ No newline at end of file diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index 38f0e584fc..c1460ed628 100644 --- a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -79,3 +79,7 @@ def similarity_search_by_text( @abstractmethod def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: """Build a query filter to filter documents by id.""" + + @abstractmethod + def load_parqs(self, data_path: str, parqs: list[str]) -> Any: + """Load documents (Parquet files) into the vector-store.""" diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index d21d1cbea1..dff0d317a3 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -3,12 +3,16 @@ """The Azure Kusto vector storage implementation package.""" +import typing from azure.kusto.data import KustoClient, KustoConnectionStringBuilder from azure.kusto.data.helpers import dataframe_from_result_table from graphrag.model.types import TextEmbedder +import pandas as pd +from pathlib import Path + import json -from typing import Any, List +from typing import Any, List, cast from .base import ( BaseVectorStore, @@ -20,6 +24,15 @@ class KustoVectorStore(BaseVectorStore): """The Azure Kusto vector storage implementation.""" + #TODO Currently loading in all the parquet fields, need to filter out the ones that are not needed. + #TODO Double check the types. This was done quickly and I may have missed something. + #TODO Check if there is a better way to get the fields to ingest into the Kusto table. These schemas are based off of me reading the files and manually making them. Maybe there is a better way to do this. 
schema_dict: typing.ClassVar[dict] = {"create_final_nodes": "(level: int, title: string, type: string, description: string, source_id: string, community: int, degree: int, human_readable_id: int, id: string, size: int, graph_embedding: dynamic, entity_type: string, top_level_node_id: string, x: int, y: int)"
        , "create_final_community_reports": "(community: int, full_content: string, level: int, rank: int, title: string, rank_explanation: string, summary: string, findings: string, full_content_json: string, id: string)"
        , "create_final_text_units": "(id: string, text: string, n_tokens: int, document_ids: string, entity_ids: string, relationship_ids: string)"
        , "create_final_relationships": "(source: string, target: string, weight: float, description: string, text_unit_ids: string, id: string, human_readable_id: string, source_degree: int, target_degree: int, rank: int)"
        , "create_final_entities": "(id: string, name: string, type: string, description: string, human_readable_id: int, graph_embedding: dynamic, text_unit_ids: string)"}
+
     def connect(self, **kwargs: Any) -> Any:
         """
         Connect to the vector storage.
@@ -42,7 +55,7 @@ def connect(self, **kwargs: Any) -> Any:
         authority_id = kwargs.get("authority_id")
 
         kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication(
-            cluster, client_id, client_secret, authority_id
+            str(cluster), str(client_id), str(client_secret), str(authority_id)
         )
         self.client = KustoClient(kcsb)
         self.database = database
@@ -164,3 +177,27 @@ def similarity_search_by_text(
         if query_embedding:
             return self.similarity_search_by_vector(query_embedding, k)
         return []
+
+
+    def load_parqs(self, data_dir, parq_names) -> Any:
+        data_path = Path(data_dir)
+        for parq_name in parq_names:
+            parq_path = data_path / f"{parq_name}.parquet"
+            if parq_path.exists():
+                parq = pd.read_parquet(parq_path)
+
+                # I wasn't sure if it was easier to rename the columns here or in the KQL queries.
+                # Most likely the KQL queries, as this is the one place I am trying to handle all the parquet files generically.
+ # parq.rename(columns={"id": "title"}, inplace=True) + # parq = cast(pd.DataFrame, parq[["title", "degree", "community"]]).rename( + # columns={"title": "name", "degree": "rank"} + # ) + + command = f".drop table {parq_name} ifexists" + self.client.execute(self.database, command) + command = f".create table {parq_name} {self.schema_dict[parq_name]}" + self.client.execute(self.database, command) + command = f".ingest inline into table {parq_name} <| {parq.to_csv(index=False, header=False)}" + self.client.execute(self.database, command) + else: + print(f"Parquet file {parq_path} not found.") \ No newline at end of file diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index 0c9ea17f54..22d04d036d 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ -119,3 +119,6 @@ def similarity_search_by_text( if query_embedding: return self.similarity_search_by_vector(query_embedding, k) return [] + + def load_parqs(self, data_path, parq_names) -> Any: + raise NotImplementedError("Loading Parquet files is not supported for LanceDB") \ No newline at end of file From 16295e5da8a3e66178c9196a6770b9bf08a8add7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Wed, 21 Aug 2024 10:41:17 -0700 Subject: [PATCH 11/87] Add Reading from graphdb --- common/graph_db_client.py | 129 +++++++++++++++++++++++ graphrag/index/create_pipeline_config.py | 1 + graphrag/index/emit/graph_db_emitter.py | 28 +++++ graphrag/query/cli.py | 15 ++- 4 files changed, 165 insertions(+), 8 deletions(-) create mode 100644 common/graph_db_client.py create mode 100644 graphrag/index/emit/graph_db_emitter.py diff --git a/common/graph_db_client.py b/common/graph_db_client.py new file mode 100644 index 0000000000..31f21c813f --- /dev/null +++ b/common/graph_db_client.py @@ -0,0 +1,129 @@ +import os +import pandas as pd + +import numpy as np + +import ast + +from gremlin_python.driver import client, serializer + +import time +import os +import json + +class GraphDBClient: + def __init__(self): + ACCOUNT_NAME = os.getenv("ACCOUNT_NAME") + ACCOUNT_KEY = os.getenv("ACCOUNT_KEY") + GRAPHDB_USERNAME = os.getenv("GRAPHDB_USERNAME") + self._client=client.Client( + url=f"wss://{ACCOUNT_NAME}.gremlin.cosmos.azure.com:443/", + traversal_source="g", + username=GRAPHDB_USERNAME, + password=f"{ACCOUNT_KEY}", + message_serializer=serializer.GraphSONSerializersV2d0(), + ) + + def result_to_df(self,result) -> pd.DataFrame: + json_data = [] + for row in result: + json_row = row[0] + properties_dict = json_row.pop('properties') + formatted_properties={} + for k,v in properties_dict.items(): + new_val=v + if isinstance(v,list) and isinstance(v[0],dict): + new_val=v[0]['value'] + if k=='description_embedding' or k =='text_unit_ids' or k=='graph_embedding': + new_val=ast.literal_eval(new_val) + if isinstance(new_val,list): + new_val=np.array(new_val) + formatted_properties[k]=new_val + json_row.update(formatted_properties) + json_data.append(json_row) + df = pd.DataFrame(json_data) + return df + + def query_vertices(self) -> pd.DataFrame: + result = self._client.submit( + message=( + "g.V()" + ), + ) + return self.result_to_df(result) + + def query_edges(self) -> pd.DataFrame: + result = self._client.submit( + message=( + "g.E()" + ), + ) + return self.result_to_df(result) + + def write_vertices(self,data: pd.DataFrame)->None: + for row in data.itertuples(): + print(row.id) + self._client.submit( + message=( + "g.addV('entity')" + ".property('id', 
prop_id)" + ".property('name', prop_name)" + ".property('type', prop_type)" + ".property('description','prop_description')" + ".property('human_readable_id', prop_human_readable_id)" + ".property('category', prop_partition_key)" + ".property(list,'description_embedding',prop_description_embedding)" + ".property(list,'graph_embedding',prop_graph_embedding)" + ".property(list,'text_unit_ids',prop_text_unit_ids)" + ), + bindings={ + "prop_id": row.id, + "prop_name": row.name, + "prop_type": row.type, + "prop_description": row.description, + "prop_human_readable_id": row.human_readable_id, + "prop_partition_key": "entities", + "prop_description_embedding":json.dumps(row.description_embedding.tolist() if row.description_embedding is not None else []), + "prop_graph_embedding":json.dumps(row.graph_embedding.tolist() if row.graph_embedding is not None else []), + "prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []), + }, + ) + time.sleep(5) + + + def write_edges(self,data: pd.DataFrame)->None: + for row in data.itertuples(): + print(row.source,row.target) + self._client.submit( + message=( + "g.V().has('name',prop_source_id)" + ".addE('connects')" + ".to(g.V().has('name',prop_target_id))" + ".property('weight',prop_weight)" + ".property(list,'text_unit_ids',prop_text_unit_ids)" + ".property('description',prop_description)" + ".property('id',prop_id)" + ".property('human_readable_id',prop_human_readable_id)" + ".property('source_degree',prop_source_degree)" + ".property('target_degree',prop_target_degree)" + ".property('rank',prop_rank)" + ".property('source',prop_source)" + ".property('target',prop_target)" + ), + bindings={ + "prop_partition_key": "entities", + "prop_source_id": row.source, + "prop_target_id": row.target, + "prop_weight": row.weight, + "prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []), + "prop_description": row.description, + "prop_id": row.id, + "prop_human_readable_id": row.human_readable_id, + "prop_source_degree": row.source_degree, + "prop_target_degree": row.target_degree, + "prop_rank": row.rank, + "prop_source": row.source, + "prop_target": row.target, + }, + ) + time.sleep(5) \ No newline at end of file diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index 4472ef8828..439f64cd75 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -370,6 +370,7 @@ def _graph_workflows( }, ), "skip_description_embedding": skip_relationship_description_embedding, + "emitter_type": TableEmitterType.Graphdb, }, ), PipelineWorkflowReference( diff --git a/graphrag/index/emit/graph_db_emitter.py b/graphrag/index/emit/graph_db_emitter.py new file mode 100644 index 0000000000..aa8ed7cfb7 --- /dev/null +++ b/graphrag/index/emit/graph_db_emitter.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""GraphDBEmitter module.""" + +import logging +import traceback + +import pandas as pd + +from gremlin_python.driver import client, serializer + +from .table_emitter import TableEmitter + +from common.graph_db_client import GraphDBClient + +class GraphDBEmitter(TableEmitter): + def __init__(self): + self.graph_db_client = GraphDBClient() + self.allowed_workflows = ['create_final_entities','create_final_relationships'] + + async def emit(self, name: str, data: pd.DataFrame) -> None: + if name not in self.allowed_workflows: + return + if name == 'create_final_entities': + self.graph_db_client.write_vertices(data) + if name == 'create_final_relationships': + self.graph_db_client.write_edges(data) \ No newline at end of file diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 81efbb550b..970ace67ce 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -30,6 +30,8 @@ read_indexer_text_units, ) +from common.graph_db_client import GraphDBClient + reporter = PrintProgressReporter("") @@ -92,14 +94,13 @@ def run_global_search( data_dir, root_dir, config = _configure_paths_and_settings( data_dir, root_dir, config_dir ) + graph_db_client = GraphDBClient() data_path = Path(data_dir) final_nodes: pd.DataFrame = pd.read_parquet( data_path / "create_final_nodes.parquet" ) - final_entities: pd.DataFrame = pd.read_parquet( - data_path / "create_final_entities.parquet" - ) + final_entities = graph_db_client.query_vertices() final_community_reports: pd.DataFrame = pd.read_parquet( data_path / "create_final_community_reports.parquet" ) @@ -134,16 +135,14 @@ def run_local_search( data_dir, root_dir, config_dir ) data_path = Path(data_dir) - + graph_db_client = GraphDBClient() final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet") final_community_reports = pd.read_parquet( data_path / "create_final_community_reports.parquet" ) final_text_units = pd.read_parquet(data_path / "create_final_text_units.parquet") - final_relationships = pd.read_parquet( - data_path / "create_final_relationships.parquet" - ) - final_entities = pd.read_parquet(data_path / "create_final_entities.parquet") + final_relationships = graph_db_client.query_edges() + final_entities = graph_db_client.query_vertices() final_covariates_path = data_path / "create_final_covariates.parquet" final_covariates = ( pd.read_parquet(final_covariates_path) From 01fb61753c9749173f0d0983e7b5868cc69c8e36 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Wed, 21 Aug 2024 10:58:48 -0700 Subject: [PATCH 12/87] Some modifications to kusto flow. 
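
The main change is the start of moving the community-report/node join into Kusto
(kt_read_indexer_reports in indexer_adapters.py below). As a rough pandas
equivalent of that KQL, assuming final_nodes and final_community_reports are the
usual parquet outputs loaded as DataFrames (illustrative only, not part of this
patch):

    nodes = final_nodes[final_nodes["level"] <= community_level]
    communities = nodes.groupby("title")["community"].max().unique()
    interm_rep = final_community_reports[
        (final_community_reports["level"] <= community_level)
        & (final_community_reports["community"].isin(communities))
    ]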
---
 .vscode/launch.json                | 29 +--------------------
 graphrag/query/cli.py              |  3 +++
 graphrag/query/indexer_adapters.py | 17 +++++++++++++
 graphrag/vector_stores/kusto.py    |  2 +-
 graphrag/vector_stores/kustodb.py  | 41 ------------------------------
 5 files changed, 22 insertions(+), 70 deletions(-)
 delete mode 100644 graphrag/vector_stores/kustodb.py

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 07d7ab4afc..ba22c0423f 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -2,34 +2,7 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Indexer",
-            "type": "debugpy",
-            "python": "E:\\graphrag\\.venv\\Scripts\\python.exe",
-            "request": "launch",
-            "cwd": "${workspaceFolder}",
-            "module": "poetry",
-            "args": ["poe", "index", "","", "--root", ".\\ragtest6"],
-            "stopOnEntry": false
-        },
-        {
-            "name": "Query",
-            "type": "debugpy",
-            "python": "E:\\graphrag\\.venv\\Scripts\\python.exe",
-            "request": "launch",
-            "cwd": "${workspaceFolder}",
-            "module": "poetry",
-            "args": ["poe", "query", "--root", ".\\ragtest6", "--method", "local", "How Guillermo got certificate access?"],
-            "stopOnEntry": false
-        },
-        {
-            "name": "Context Switch",
-            "type": "debugpy",
-            "python": "E:\\graphrag\\.venv\\Scripts\\python.exe",
-            "request": "launch",
-            "cwd": "${workspaceFolder}",
-            "module": "poetry",
-            "args": ["poe", "index", "--context-id", "f367f396-b0bf-4547-8d9c-331fc7a39433", "--context-operation", "activate", "--root", ".\\ragtest6"],
-            "stopOnEntry": false
+
         }
     ]
 }
\ No newline at end of file
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index 79b03a7ee8..cd2605db36 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -34,6 +34,7 @@
     read_indexer_entities,
     read_indexer_relationships,
     read_indexer_reports,
+    kt_read_indexer_reports,
     read_indexer_text_units,
 )
 
@@ -289,6 +290,8 @@ def run_content_store_local_search(
     #     else []
     #     )
 
+    reports = kt_read_indexer_reports(description_embedding_store, community_level)
+
     #TODO KQLify this. I know at least the read_indexer_reports needs to be done in Kusto. We are joining the community reports & final nodes.
    # search_engine = get_local_search_engine(
    #     config,
    #     reports=read_indexer_reports(
    #         final_community_reports, final_nodes, community_level
    #     ),
    #     text_units=read_indexer_text_units(final_text_units),
    #     entities=entities,
    #     relationships=read_indexer_relationships(final_relationships),
    #     covariates={"claims": covariates},
    #     description_embedding_store=description_embedding_store,
    #     response_type=response_type,
    # )
diff --git a/graphrag/query/indexer_adapters.py b/graphrag/query/indexer_adapters.py
index 132db92c1f..c3b5f6e1ab 100644
--- a/graphrag/query/indexer_adapters.py
+++ b/graphrag/query/indexer_adapters.py
@@ -19,6 +19,7 @@
     read_text_units,
 )
 
+from graphrag.vector_stores.kusto import KustoVectorStore
 
 def read_indexer_text_units(final_text_units: pd.DataFrame) -> list[TextUnit]:
     """Read in the Text Units from the raw indexing outputs."""
@@ -58,6 +59,22 @@ def read_indexer_relationships(final_relationships: pd.DataFrame) -> list[Relati
         attributes_cols=["rank"],
     )
 
+def kt_read_indexer_reports(
+    vs: KustoVectorStore,
+    community_level: int,
+) -> list[CommunityReport]:
+    """Materialize the community-report/node join as an interm_rep table in Kusto."""
+    vs.client.execute(vs.database, '.drop table interm_rep ifexists')
+
+    cmd = f'''
+    .set interm_rep <| (create_final_community_reports | where level <= {community_level} |
+    join kind=inner (create_final_nodes | where level <= {community_level} | summarize community=max(community) by ['title'] | summarize by community )
+    on community | project-away community1)
+    '''
+
+    res = vs.client.execute(vs.database, cmd)
+    return True #TODO: error checking should be added later; eventually this should return the joined reports as list[CommunityReport]
 
 def read_indexer_reports(
     final_community_reports: pd.DataFrame,
diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py
index dff0d317a3..1a856724e7 100644
--- a/graphrag/vector_stores/kusto.py
+++ b/graphrag/vector_stores/kusto.py
@@ -30,7 +30,7 @@ class KustoVectorStore(BaseVectorStore):
     schema_dict: typing.ClassVar[dict] = {"create_final_nodes": "(level: int, title: string, type: string, description: string, source_id: string, community: int, degree: int, human_readable_id: int, id: string, size: int, graph_embedding: dynamic, entity_type: string, top_level_node_id: string, x: int, y: int)"
         , "create_final_community_reports": "(community: int, full_content: string, level: int, rank: int, title: string, rank_explanation: string, summary: string, findings: string, full_content_json: string, id: string)"
         , "create_final_text_units": "(id: string, text: string, n_tokens: int, document_ids: string, entity_ids: string, relationship_ids: string)"
-        , "create_final_relationships": "(source: string, target: string, weight: float, description: string, text_unit_ids: string, id: string, human_readable_id: string, source_degree: int, target_degree: int, rank: int)"
+        , "create_final_relationships": "(source: string, target: string, weight: real, description: string, text_unit_ids: string, id: string, human_readable_id: string, source_degree: int, target_degree: int, rank: int)"
         , "create_final_entities": "(id: string, name: string, type: string, description: string, human_readable_id: int, graph_embedding: dynamic, text_unit_ids: string)"}
diff --git a/graphrag/vector_stores/kustodb.py b/graphrag/vector_stores/kustodb.py
deleted file mode 100644
index b009e2cbd6..0000000000
--- a/graphrag/vector_stores/kustodb.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# write kusto db here.
-import lancedb as lancedb  # noqa: I001 (Ruff was breaking on this file imports, even tho they were sorted and passed local tests)
-from graphrag.model.types import TextEmbedder
-from typing import Any
-
-from .base import (
-    BaseVectorStore,
-    VectorStoreDocument,
-    VectorStoreSearchResult,
-)
-class KustoDBVectorStore(BaseVectorStore):
-    """Kusto vector store."""
-
-    def __init__(self, **kwargs):
-        """Initialize the Kusto vector store."""
-        pass
-
-    def connect(self, **kwargs: Any) -> Any:
-        """Connect to the vector storage."""
-        pass
-
-    def load_documents(
-        self, documents: list[VectorStoreDocument], overwrite: bool = True
-    ) -> None:
-        """Load documents into vector storage."""
-        pass
-
-    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
-        """Build a query filter to filter documents by id."""
-        pass
-
-    def similarity_search_by_vector(
-        self, query_embedding: list[float], k: int = 10, **kwargs: Any
-    ) -> list[VectorStoreSearchResult]:
-        """Perform a vector-based similarity search."""
-        return []
-    def similarity_search_by_text(
-        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
-    ) -> list[VectorStoreSearchResult]:
-        """Perform a similarity search using a given input text."""
-        return []

From f78bd407e127dd79ae875671a232aa6d2887ad86 Mon Sep 17 00:00:00 2001
From: Sirus Sh
Date: Wed, 21 Aug 2024 11:30:58 -0700
Subject: [PATCH 13/87] Add a skeleton for launch.json

---
 .vscode/launch.json | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/.vscode/launch.json b/.vscode/launch.json
index ba22c0423f..1d2c17bcea 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -2,7 +2,14 @@
     "version": "0.2.0",
     "configurations": [
         {
-
+            "name": "",
+            "type": "debugpy",
+            "python": "",
+            "request": "launch",
+            "cwd": "${workspaceFolder}",
+            "module": "poetry",
+            "args": ["poe", "", "other args"],
+            "stopOnEntry": false
         }
     ]
 }
\ No newline at end of file

From be00a937ca740c8e4fc9c39a83d1add85fa7654c Mon Sep 17 00:00:00 2001
From: Sirus Sh
Date: Wed, 21 Aug 2024 11:34:17 -0700
Subject: [PATCH 14/87] Fix a typo

---
 .vscode/launch.json | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 1d2c17bcea..95819f5f64 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -2,7 +2,7 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "",
+        "name": "",
         "type": "debugpy",

From 09181864290c694c05fe23604533a848de12afe5 Mon Sep 17 00:00:00 2001
From: Liam Dannaher
Date: Wed, 21 Aug 2024 12:02:44 -0700
Subject: [PATCH 15/87] Creating entities table in Kusto.

---
 graphrag/query/cli.py                     | 10 ++++++++--
 graphrag/vector_stores/azure_ai_search.py |  6 ++++--
 graphrag/vector_stores/base.py            |  5 +++++
 graphrag/vector_stores/kusto.py           |  2 ++
 graphrag/vector_stores/lancedb.py         |  5 ++++-
 5 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index cd2605db36..d3aa53fe4c 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -276,7 +276,7 @@ def run_content_store_local_search(
     description_embedding_store.load_parqs(data_dir, ["create_final_nodes", "create_final_community_reports", "create_final_text_units", "create_final_relationships", "create_final_entities"])
 
     #TODO KQLify this. This merge of nodes & entities needs to happen in Kusto.
- # entities = read_indexer_entities(final_nodes, final_entities, community_level) + create_entities_table(description_embedding_store, community_level) # description_embedding_store = __get_embedding_description_store( # entities=entities, # description_embedding_store=description_embedding_store, @@ -313,7 +313,13 @@ def run_content_store_local_search( return True #Obviously this is a placeholder due to all the TODOs above. - +# Create entities table similar to read_indexer_entities, but creating that table in Kusto, not in memory. +def create_entities_table(description_embedding_store: BaseVectorStore, community_level: int): + description_embedding_store.execute_query(f".set-or-replace entities <| create_final_nodes \ + | where level <= {community_level} \ + | project community=coalesce(community, 0), name=['title'], rank=degree \ + | summarize community=max(community) by name, rank \ + | join kind=inner create_final_entities on name") def run_content_store_global_search( config_dir: str | None, diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index f0556b6d2b..cd78078d35 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -194,5 +194,7 @@ def similarity_search_by_text( return [] def load_parqs(self, data_path, parq_names) -> Any: - """Load documents (Parquet files) into the vector-store.""" - raise NotImplementedError("Loading Parquet files is not supported for Azure AI Search") \ No newline at end of file + raise NotImplementedError("Loading Parquet files is not supported for Azure AI Search") + + def execute_query(self, query: str) -> Any: + return super().execute_query(query) \ No newline at end of file diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index c1460ed628..be7a400d5c 100644 --- a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -83,3 +83,8 @@ def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: @abstractmethod def load_parqs(self, data_path: str, parqs: list[str]) -> Any: """Load documents (Parquet files) into the vector-store.""" + + #TODO This is temporary until I take out the client from the vector store class + @abstractmethod + def execute_query(self, query: str) -> Any: + """Execute a query in the vector-store.""" \ No newline at end of file diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 1a856724e7..9010bd147f 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -178,6 +178,8 @@ def similarity_search_by_text( return self.similarity_search_by_vector(query_embedding, k) return [] + def execute_query(self, query: str) -> Any: + self.client.execute(self.database, f"{query}") def load_parqs(self, data_dir, parq_names) -> Any: data_path = Path(data_dir) diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index 22d04d036d..ddf6effb59 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ -121,4 +121,7 @@ def similarity_search_by_text( return [] def load_parqs(self, data_path, parq_names) -> Any: - raise NotImplementedError("Loading Parquet files is not supported for LanceDB") \ No newline at end of file + raise NotImplementedError("Loading Parquet files is not supported for LanceDB") + + def execute_query(self, query: str) -> Any: + return super().execute_query(query) \ No newline at end of file From a6b068941e5e90c4a73bbe297db4265c7856b3d6 Mon Sep 17 00:00:00 2001 From: 
Amritpal Singh Date: Wed, 21 Aug 2024 15:44:54 -0700 Subject: [PATCH 16/87] Minor fixes --- graphrag/index/input/text.py | 7 +++++-- graphrag/query/cli.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/graphrag/index/input/text.py b/graphrag/index/input/text.py index eceb8bb06a..fe5ae25f8e 100644 --- a/graphrag/index/input/text.py +++ b/graphrag/index/input/text.py @@ -39,13 +39,16 @@ async def load_file( new_item["id"] = gen_md5_hash(new_item, new_item.keys()) new_item["title"] = str(Path(path).name) return new_item - + base_dir = config.base_dir + if config.type == "file": + # base dir is already being added to root dir in case of type file. + base_dir = None files = list( storage.find( re.compile(config.file_pattern), progress=progress, file_filter=config.file_filter, - base_dir=config.base_dir, + base_dir=base_dir ) ) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index d3aa53fe4c..756c5746e6 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -235,7 +235,7 @@ def read_paraquet_file(config:GraphRagConfig, path: str, storageType: StorageTyp else: file_path = Path(path) if not file_path.exists(): - raise ValueError(f"Data path {file_path} does not exist.") + return pd.DataFrame() return pd.read_parquet(path) # TODO I split this out for now to preserve how the original local search worked. # I don't think this will necessarily be permanently separate. From 116f524017da6f6ce8af8bb61cbf9563381951d7 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Thu, 22 Aug 2024 09:48:09 -0700 Subject: [PATCH 17/87] minor updates: lancedb style entities --- graphrag/query/cli.py | 7 ++++++- graphrag/vector_stores/kusto.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 756c5746e6..2128aff077 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -315,11 +315,16 @@ def run_content_store_local_search( # Create entities table similar to read_indexer_entities, but creating that table in Kusto, not in memory. 
def create_entities_table(description_embedding_store: BaseVectorStore, community_level: int): - description_embedding_store.execute_query(f".set-or-replace entities <| create_final_nodes \ + description_embedding_store.execute_query(".drop table entities ifexists") #make sure a stale schema doesn't exist + description_embedding_store.execute_query(".set entities <| (create_final_entities | \ + project id,title=name,text=description,vector=description_embeddings)") + ''' + description_embedding_store.execute_query(f".set entities <| create_final_nodes \ | where level <= {community_level} \ | project community=coalesce(community, 0), name=['title'], rank=degree \ | summarize community=max(community) by name, rank \ | join kind=inner create_final_entities on name") + ''' def run_content_store_global_search( config_dir: str | None, diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 9010bd147f..7ca938b113 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -31,7 +31,7 @@ class KustoVectorStore(BaseVectorStore): , "create_final_community_reports": "(community: int, full_content: string, level: int, rank: int, title: string, rank_explanation: string, summary: string, findings: string, full_content_json: string, id: string)" , "create_final_text_units": "(id: string, text: string, n_tokens: int, document_ids: string, entity_ids: string, relationship_ids: string)" , "create_final_relationships": "(source: string, target: string, weight: real, description: string, text_unit_ids: string, id: string, human_readable_id: string, source_degree: int, target_degree: int, rank: int)" - , "create_final_entities": "(id: string, name: string, type: string, description: string, human_readable_id: int, graph_embedding: dynamic, text_unit_ids: string)"} + , "create_final_entities": "(id: string, name: string, type: string, description: string, human_readable_id: int, graph_embedding: dynamic, text_unit_ids: string, description_embeddings: dynamic )"}} def connect(self, **kwargs: Any) -> Any: """ From 83798a41e601dd5628e30ecdfa3fb3b01b2f7ef2 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Thu, 22 Aug 2024 10:02:26 -0700 Subject: [PATCH 18/87] syntax --- graphrag/vector_stores/kusto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 7ca938b113..f934231c77 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -31,7 +31,7 @@ class KustoVectorStore(BaseVectorStore): , "create_final_community_reports": "(community: int, full_content: string, level: int, rank: int, title: string, rank_explanation: string, summary: string, findings: string, full_content_json: string, id: string)" , "create_final_text_units": "(id: string, text: string, n_tokens: int, document_ids: string, entity_ids: string, relationship_ids: string)" , "create_final_relationships": "(source: string, target: string, weight: real, description: string, text_unit_ids: string, id: string, human_readable_id: string, source_degree: int, target_degree: int, rank: int)" - , "create_final_entities": "(id: string, name: string, type: string, description: string, human_readable_id: int, graph_embedding: dynamic, text_unit_ids: string, description_embeddings: dynamic )"}} + , "create_final_entities": "(id: string, name: string, type: string, description: string, human_readable_id: int, graph_embedding: dynamic, text_unit_ids: string, description_embeddings: dynamic )"} def 
connect(self, **kwargs: Any) -> Any: """ From 02b3a959fa5f3e0ddf7ed162353452044901d7ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Tue, 20 Aug 2024 13:27:24 -0700 Subject: [PATCH 19/87] Writing edges and vertices into graphdb --- graphrag/index/create_pipeline_config.py | 2 + graphrag/index/emit/factories.py | 3 + graphrag/index/emit/types.py | 1 + poetry.lock | 459 ++++++++++++++++++++++- pyproject.toml | 1 + 5 files changed, 465 insertions(+), 1 deletion(-) diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index 22dba20029..4472ef8828 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -6,6 +6,7 @@ import json import logging from pathlib import Path +from .emit.types import TableEmitterType from graphrag.config.enums import ( CacheType, @@ -354,6 +355,7 @@ def _graph_workflows( ), "skip_name_embedding": skip_entity_name_embedding, "skip_description_embedding": skip_entity_description_embedding, + "emitter_type": TableEmitterType.Graphdb, }, ), PipelineWorkflowReference( diff --git a/graphrag/index/emit/factories.py b/graphrag/index/emit/factories.py index 84afa68443..cd9e203917 100644 --- a/graphrag/index/emit/factories.py +++ b/graphrag/index/emit/factories.py @@ -9,6 +9,7 @@ from .csv_table_emitter import CSVTableEmitter from .json_table_emitter import JsonTableEmitter from .parquet_table_emitter import ParquetTableEmitter +from .graph_db_emitter import GraphDBEmitter from .table_emitter import TableEmitter from .types import TableEmitterType @@ -24,6 +25,8 @@ def create_table_emitter( return ParquetTableEmitter(storage, on_error) case TableEmitterType.CSV: return CSVTableEmitter(storage) + case TableEmitterType.Graphdb: + return GraphDBEmitter() case _: msg = f"Unsupported table emitter type: {emitter_type}" raise ValueError(msg) diff --git a/graphrag/index/emit/types.py b/graphrag/index/emit/types.py index ab3452856f..0b0ff88541 100644 --- a/graphrag/index/emit/types.py +++ b/graphrag/index/emit/types.py @@ -12,3 +12,4 @@ class TableEmitterType(str, Enum): Json = "json" Parquet = "parquet" CSV = "csv" + Graphdb = "graphdb" diff --git a/poetry.lock b/poetry.lock index 278c8ad98b..49b84c97e0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,17 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
+[[package]] +name = "aenum" +version = "3.1.15" +description = "Advanced Enumerations (compatible with Python's stdlib Enum), NamedTuples, and NamedConstants" +optional = false +python-versions = "*" +files = [ + {file = "aenum-3.1.15-py2-none-any.whl", hash = "sha256:27b1710b9d084de6e2e695dab78fe9f269de924b51ae2850170ee7e1ca6288a5"}, + {file = "aenum-3.1.15-py3-none-any.whl", hash = "sha256:e0dfaeea4c2bd362144b87377e2c61d91958c5ed0b4daf89cb6f45ae23af6288"}, + {file = "aenum-3.1.15.tar.gz", hash = "sha256:8cbd76cd18c4f870ff39b24284d3ea028fbe8731a58df3aa581e434c575b9559"}, +] + [[package]] name = "aiofiles" version = "24.1.0" @@ -11,6 +23,114 @@ files = [ {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, ] +[[package]] +name = "aiohappyeyeballs" +version = "2.3.7" +description = "Happy Eyeballs for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiohappyeyeballs-2.3.7-py3-none-any.whl", hash = "sha256:337ce4dc0e99eb697c3c5a77d6cb3c52925824d9a67ac0dea7c55b8a2d60b222"}, + {file = "aiohappyeyeballs-2.3.7.tar.gz", hash = "sha256:e794cd29ba6a14078092984e43688212a19081de3a73b6796c2fdeb3706dd6ce"}, +] + +[[package]] +name = "aiohttp" +version = "3.10.3" +description = "Async http client/server framework (asyncio)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiohttp-3.10.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc36cbdedf6f259371dbbbcaae5bb0e95b879bc501668ab6306af867577eb5db"}, + {file = "aiohttp-3.10.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:85466b5a695c2a7db13eb2c200af552d13e6a9313d7fa92e4ffe04a2c0ea74c1"}, + {file = "aiohttp-3.10.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:71bb1d97bfe7e6726267cea169fdf5df7658831bb68ec02c9c6b9f3511e108bb"}, + {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baec1eb274f78b2de54471fc4c69ecbea4275965eab4b556ef7a7698dee18bf2"}, + {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13031e7ec1188274bad243255c328cc3019e36a5a907978501256000d57a7201"}, + {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2bbc55a964b8eecb341e492ae91c3bd0848324d313e1e71a27e3d96e6ee7e8e8"}, + {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8cc0564b286b625e673a2615ede60a1704d0cbbf1b24604e28c31ed37dc62aa"}, + {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f817a54059a4cfbc385a7f51696359c642088710e731e8df80d0607193ed2b73"}, + {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8542c9e5bcb2bd3115acdf5adc41cda394e7360916197805e7e32b93d821ef93"}, + {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:671efce3a4a0281060edf9a07a2f7e6230dca3a1cbc61d110eee7753d28405f7"}, + {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:0974f3b5b0132edcec92c3306f858ad4356a63d26b18021d859c9927616ebf27"}, + {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:44bb159b55926b57812dca1b21c34528e800963ffe130d08b049b2d6b994ada7"}, + {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:6ae9ae382d1c9617a91647575255ad55a48bfdde34cc2185dd558ce476bf16e9"}, + {file = "aiohttp-3.10.3-cp310-cp310-win32.whl", hash = 
"sha256:aed12a54d4e1ee647376fa541e1b7621505001f9f939debf51397b9329fd88b9"}, + {file = "aiohttp-3.10.3-cp310-cp310-win_amd64.whl", hash = "sha256:b51aef59370baf7444de1572f7830f59ddbabd04e5292fa4218d02f085f8d299"}, + {file = "aiohttp-3.10.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e021c4c778644e8cdc09487d65564265e6b149896a17d7c0f52e9a088cc44e1b"}, + {file = "aiohttp-3.10.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:24fade6dae446b183e2410a8628b80df9b7a42205c6bfc2eff783cbeedc224a2"}, + {file = "aiohttp-3.10.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bc8e9f15939dacb0e1f2d15f9c41b786051c10472c7a926f5771e99b49a5957f"}, + {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5a9ec959b5381271c8ec9310aae1713b2aec29efa32e232e5ef7dcca0df0279"}, + {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a5d0ea8a6467b15d53b00c4e8ea8811e47c3cc1bdbc62b1aceb3076403d551f"}, + {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c9ed607dbbdd0d4d39b597e5bf6b0d40d844dfb0ac6a123ed79042ef08c1f87e"}, + {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3e66d5b506832e56add66af88c288c1d5ba0c38b535a1a59e436b300b57b23e"}, + {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fda91ad797e4914cca0afa8b6cccd5d2b3569ccc88731be202f6adce39503189"}, + {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:61ccb867b2f2f53df6598eb2a93329b5eee0b00646ee79ea67d68844747a418e"}, + {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:6d881353264e6156f215b3cb778c9ac3184f5465c2ece5e6fce82e68946868ef"}, + {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:b031ce229114825f49cec4434fa844ccb5225e266c3e146cb4bdd025a6da52f1"}, + {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5337cc742a03f9e3213b097abff8781f79de7190bbfaa987bd2b7ceb5bb0bdec"}, + {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ab3361159fd3dcd0e48bbe804006d5cfb074b382666e6c064112056eb234f1a9"}, + {file = "aiohttp-3.10.3-cp311-cp311-win32.whl", hash = "sha256:05d66203a530209cbe40f102ebaac0b2214aba2a33c075d0bf825987c36f1f0b"}, + {file = "aiohttp-3.10.3-cp311-cp311-win_amd64.whl", hash = "sha256:70b4a4984a70a2322b70e088d654528129783ac1ebbf7dd76627b3bd22db2f17"}, + {file = "aiohttp-3.10.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:166de65e2e4e63357cfa8417cf952a519ac42f1654cb2d43ed76899e2319b1ee"}, + {file = "aiohttp-3.10.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7084876352ba3833d5d214e02b32d794e3fd9cf21fdba99cff5acabeb90d9806"}, + {file = "aiohttp-3.10.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d98c604c93403288591d7d6d7d6cc8a63459168f8846aeffd5b3a7f3b3e5e09"}, + {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d73b073a25a0bb8bf014345374fe2d0f63681ab5da4c22f9d2025ca3e3ea54fc"}, + {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8da6b48c20ce78f5721068f383e0e113dde034e868f1b2f5ee7cb1e95f91db57"}, + {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3a9dcdccf50284b1b0dc72bc57e5bbd3cc9bf019060dfa0668f63241ccc16aa7"}, + {file = 
"aiohttp-3.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56fb94bae2be58f68d000d046172d8b8e6b1b571eb02ceee5535e9633dcd559c"}, + {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bf75716377aad2c718cdf66451c5cf02042085d84522aec1f9246d3e4b8641a6"}, + {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6c51ed03e19c885c8e91f574e4bbe7381793f56f93229731597e4a499ffef2a5"}, + {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b84857b66fa6510a163bb083c1199d1ee091a40163cfcbbd0642495fed096204"}, + {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c124b9206b1befe0491f48185fd30a0dd51b0f4e0e7e43ac1236066215aff272"}, + {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3461d9294941937f07bbbaa6227ba799bc71cc3b22c40222568dc1cca5118f68"}, + {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:08bd0754d257b2db27d6bab208c74601df6f21bfe4cb2ec7b258ba691aac64b3"}, + {file = "aiohttp-3.10.3-cp312-cp312-win32.whl", hash = "sha256:7f9159ae530297f61a00116771e57516f89a3de6ba33f314402e41560872b50a"}, + {file = "aiohttp-3.10.3-cp312-cp312-win_amd64.whl", hash = "sha256:e1128c5d3a466279cb23c4aa32a0f6cb0e7d2961e74e9e421f90e74f75ec1edf"}, + {file = "aiohttp-3.10.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:d1100e68e70eb72eadba2b932b185ebf0f28fd2f0dbfe576cfa9d9894ef49752"}, + {file = "aiohttp-3.10.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a541414578ff47c0a9b0b8b77381ea86b0c8531ab37fc587572cb662ccd80b88"}, + {file = "aiohttp-3.10.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d5548444ef60bf4c7b19ace21f032fa42d822e516a6940d36579f7bfa8513f9c"}, + {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ba2e838b5e6a8755ac8297275c9460e729dc1522b6454aee1766c6de6d56e5e"}, + {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:48665433bb59144aaf502c324694bec25867eb6630fcd831f7a893ca473fcde4"}, + {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bac352fceed158620ce2d701ad39d4c1c76d114255a7c530e057e2b9f55bdf9f"}, + {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b0f670502100cdc567188c49415bebba947eb3edaa2028e1a50dd81bd13363f"}, + {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43b09f38a67679e32d380fe512189ccb0b25e15afc79b23fbd5b5e48e4fc8fd9"}, + {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:cd788602e239ace64f257d1c9d39898ca65525583f0fbf0988bcba19418fe93f"}, + {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:214277dcb07ab3875f17ee1c777d446dcce75bea85846849cc9d139ab8f5081f"}, + {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:32007fdcaab789689c2ecaaf4b71f8e37bf012a15cd02c0a9db8c4d0e7989fa8"}, + {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:123e5819bfe1b87204575515cf448ab3bf1489cdeb3b61012bde716cda5853e7"}, + {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:812121a201f0c02491a5db335a737b4113151926a79ae9ed1a9f41ea225c0e3f"}, + {file = "aiohttp-3.10.3-cp38-cp38-win32.whl", hash = "sha256:b97dc9a17a59f350c0caa453a3cb35671a2ffa3a29a6ef3568b523b9113d84e5"}, + {file = 
"aiohttp-3.10.3-cp38-cp38-win_amd64.whl", hash = "sha256:3731a73ddc26969d65f90471c635abd4e1546a25299b687e654ea6d2fc052394"}, + {file = "aiohttp-3.10.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:38d91b98b4320ffe66efa56cb0f614a05af53b675ce1b8607cdb2ac826a8d58e"}, + {file = "aiohttp-3.10.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9743fa34a10a36ddd448bba8a3adc2a66a1c575c3c2940301bacd6cc896c6bf1"}, + {file = "aiohttp-3.10.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7c126f532caf238031c19d169cfae3c6a59129452c990a6e84d6e7b198a001dc"}, + {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:926e68438f05703e500b06fe7148ef3013dd6f276de65c68558fa9974eeb59ad"}, + {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:434b3ab75833accd0b931d11874e206e816f6e6626fd69f643d6a8269cd9166a"}, + {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d35235a44ec38109b811c3600d15d8383297a8fab8e3dec6147477ec8636712a"}, + {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59c489661edbd863edb30a8bd69ecb044bd381d1818022bc698ba1b6f80e5dd1"}, + {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50544fe498c81cb98912afabfc4e4d9d85e89f86238348e3712f7ca6a2f01dab"}, + {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:09bc79275737d4dc066e0ae2951866bb36d9c6b460cb7564f111cc0427f14844"}, + {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:af4dbec58e37f5afff4f91cdf235e8e4b0bd0127a2a4fd1040e2cad3369d2f06"}, + {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:b22cae3c9dd55a6b4c48c63081d31c00fc11fa9db1a20c8a50ee38c1a29539d2"}, + {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ba562736d3fbfe9241dad46c1a8994478d4a0e50796d80e29d50cabe8fbfcc3f"}, + {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:f25d6c4e82d7489be84f2b1c8212fafc021b3731abdb61a563c90e37cced3a21"}, + {file = "aiohttp-3.10.3-cp39-cp39-win32.whl", hash = "sha256:b69d832e5f5fa15b1b6b2c8eb6a9fd2c0ec1fd7729cb4322ed27771afc9fc2ac"}, + {file = "aiohttp-3.10.3-cp39-cp39-win_amd64.whl", hash = "sha256:673bb6e3249dc8825df1105f6ef74e2eab779b7ff78e96c15cadb78b04a83752"}, + {file = "aiohttp-3.10.3.tar.gz", hash = "sha256:21650e7032cc2d31fc23d353d7123e771354f2a3d5b05a5647fc30fea214e696"}, +] + +[package.dependencies] +aiohappyeyeballs = ">=2.3.0" +aiosignal = ">=1.1.2" +async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} +attrs = ">=17.3.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +yarl = ">=1.0,<2.0" + +[package.extras] +speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] + [[package]] name = "aiolimiter" version = "1.1.0" @@ -22,6 +142,20 @@ files = [ {file = "aiolimiter-1.1.0.tar.gz", hash = "sha256:461cf02f82a29347340d031626c92853645c099cb5ff85577b831a7bd21132b5"}, ] +[[package]] +name = "aiosignal" +version = "1.3.1" +description = "aiosignal: a list of registered asynchronous callbacks" +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, + {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, +] + +[package.dependencies] +frozenlist = 
">=1.1.0" + [[package]] name = "annotated-types" version = "0.7.0" @@ -188,6 +322,17 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} +[[package]] +name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, +] + [[package]] name = "attrs" version = "24.1.0" @@ -1204,6 +1349,7 @@ files = [ {file = "fastparquet-2024.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5626fc72204001b7e82fedb4b02174ecb4e2d4143b38b4ea8d2f9eb65f6b000e"}, {file = "fastparquet-2024.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c8b2e86fe6488cce0e3d41263bb0296ef9bbb875a2fca09d67d7685640017a66"}, {file = "fastparquet-2024.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2a951106782d51e5ab110beaad29c4aa0537f045711bb0bf146f65aeaed14174"}, + {file = "fastparquet-2024.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:cd3473d3e299bfb04c0ac7726cca5d13ee450cc2387ee7fd70587ca150647315"}, {file = "fastparquet-2024.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:47695037fdc534ef4247f25ccf17dcbd8825be6ecb70c54ca54d588a794f4a6d"}, {file = "fastparquet-2024.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc3d35ff8341cd65baecac71062e9d73393d7afda207b3421709c1d3f4baa194"}, {file = "fastparquet-2024.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:691348cc85890663dd3c0bb02544d38d4c07a0c3d68837324dc01007301150b5"}, @@ -1301,6 +1447,92 @@ files = [ {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] +[[package]] +name = "frozenlist" +version = "1.4.1" +description = "A list-like structure which implements collections.abc.MutableSequence" +optional = false +python-versions = ">=3.8" +files = [ + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, + {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, + {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, + {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, + {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, + {file = 
"frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, + {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, + {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, + {file 
= "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, + {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, + {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, + {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, + {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = 
"sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, + {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, + {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, +] + [[package]] name = "fsspec" version = "2024.6.1" @@ -1439,6 +1671,28 @@ files = [ {file = "graspologic_native-1.2.1.tar.gz", hash = "sha256:72b7586028a91e9fef9af0ef314d368f0240c18dca99e6e6c546334359a8610a"}, ] +[[package]] +name = "gremlinpython" +version = "3.7.2" +description = "Gremlin-Python for Apache TinkerPop" +optional = false +python-versions = "*" +files = [ + {file = "gremlinpython-3.7.2-py2.py3-none-any.whl", hash = "sha256:ff925632f70fc5afff6864627fa69de8df577090a6eea63f10f0fed7194f98bf"}, + {file = "gremlinpython-3.7.2.tar.gz", hash = "sha256:58d12e4af81210d7770d54da8f1f586de816bc00858e095f76326d338ea34acd"}, +] + +[package.dependencies] +aenum = ">=1.4.5,<4.0.0" +aiohttp = ">=3.8.0,<4.0.0" +isodate = ">=0.6.0,<1.0.0" +nest-asyncio = "*" +six = ">=1.10.0,<2.0.0" + +[package.extras] +kerberos = ["kerberos (>=1.3.0,<2.0.0)"] +ujson = ["ujson (>=2.0.0)"] + [[package]] name = "h11" version = "0.14.0" @@ -2522,6 +2776,105 @@ files = [ msal = ">=1.29,<2" portalocker = ">=1.4,<3" +[[package]] +name = "multidict" +version = "6.0.5" +description = "multidict implementation" +optional = false +python-versions = ">=3.7" +files = [ + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = 
"sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, + {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, + {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, + {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, + {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, + {file = 
"multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, + {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, + {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, + {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, + {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, + {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = 
"sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, + {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, + {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, + {file = 
"multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, + {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, + {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, + {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, + {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, +] + [[package]] name = "nbclient" version = "0.10.0" @@ -3786,6 +4139,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -5143,6 +5497,109 @@ files = [ {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"}, ] +[[package]] +name = "yarl" +version = "1.9.4" +description = "Yet another URL library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"}, + {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"}, + {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"}, + {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"}, + {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"}, + {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"}, + {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"}, + {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"}, + {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"}, + {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"}, + {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"}, + {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"}, + {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"}, + {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"}, + {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"}, + {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" + [[package]] name = "zipp" version = "3.19.2" @@ -5161,4 +5618,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "3b42a857248dd22d200eb65786cca8e7ce0378fb603083c62a6d11a8f47f8b45" +content-hash = "b2c972b8acfe25a9da7a68b3c2694a431908d66738b8082d48bfe124a78585b9" diff --git a/pyproject.toml b/pyproject.toml index 90af801003..6978359eb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ typing-extensions = "^4.12.2" azure-storage-blob = "^12.19.0" azure-identity = "^1.17.1" json-repair = "^0.25.3" +gremlinpython = "^3.7.2" [tool.poetry.group.dev.dependencies] coverage = "^7.6.0" From d79fa7ba0e4972943626cb6e6628a5d2a58676ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Wed, 21 Aug 2024 10:41:17 -0700 Subject: [PATCH 20/87] Add Reading from graphdb --- common/graph_db_client.py | 129 +++++++++++++++++++++++ graphrag/index/create_pipeline_config.py | 1 + graphrag/index/emit/graph_db_emitter.py | 28 +++++ graphrag/query/cli.py | 14 ++- 4 files changed, 167 insertions(+), 5 deletions(-) create mode 100644 
common/graph_db_client.py
 create mode 100644 graphrag/index/emit/graph_db_emitter.py

diff --git a/common/graph_db_client.py b/common/graph_db_client.py
new file mode 100644
index 0000000000..31f21c813f
--- /dev/null
+++ b/common/graph_db_client.py
@@ -0,0 +1,129 @@
+import os
+import pandas as pd
+
+import numpy as np
+
+import ast
+
+from gremlin_python.driver import client, serializer
+
+import time
+import json
+
+class GraphDBClient:
+    def __init__(self):
+        ACCOUNT_NAME = os.getenv("ACCOUNT_NAME")
+        ACCOUNT_KEY = os.getenv("ACCOUNT_KEY")
+        GRAPHDB_USERNAME = os.getenv("GRAPHDB_USERNAME")
+        self._client=client.Client(
+            url=f"wss://{ACCOUNT_NAME}.gremlin.cosmos.azure.com:443/",
+            traversal_source="g",
+            username=GRAPHDB_USERNAME,
+            password=ACCOUNT_KEY,
+            message_serializer=serializer.GraphSONSerializersV2d0(),
+        )
+
+    def result_to_df(self,result) -> pd.DataFrame:
+        # Flatten Gremlin results; embedding/text_unit_ids properties are decoded back into arrays.
+        json_data = []
+        for row in result:
+            json_row = row[0]
+            properties_dict = json_row.pop('properties')
+            formatted_properties={}
+            for k,v in properties_dict.items():
+                new_val=v
+                if isinstance(v,list) and isinstance(v[0],dict):
+                    new_val=v[0]['value']
+                if k=='description_embedding' or k=='text_unit_ids' or k=='graph_embedding':
+                    new_val=ast.literal_eval(new_val)
+                if isinstance(new_val,list):
+                    new_val=np.array(new_val)
+                formatted_properties[k]=new_val
+            json_row.update(formatted_properties)
+            json_data.append(json_row)
+        df = pd.DataFrame(json_data)
+        return df
+
+    def query_vertices(self) -> pd.DataFrame:
+        result = self._client.submit(
+            message=(
+                "g.V()"
+            ),
+        )
+        return self.result_to_df(result)
+
+    def query_edges(self) -> pd.DataFrame:
+        result = self._client.submit(
+            message=(
+                "g.E()"
+            ),
+        )
+        return self.result_to_df(result)
+
+    def write_vertices(self,data: pd.DataFrame)->None:
+        for row in data.itertuples():
+            print(row.id)
+            self._client.submit(
+                message=(
+                    "g.addV('entity')"
+                    ".property('id', prop_id)"
+                    ".property('name', prop_name)"
+                    ".property('type', prop_type)"
+                    ".property('description', prop_description)"
+                    ".property('human_readable_id', prop_human_readable_id)"
+                    ".property('category', prop_partition_key)"
+                    ".property(list,'description_embedding',prop_description_embedding)"
+                    ".property(list,'graph_embedding',prop_graph_embedding)"
+                    ".property(list,'text_unit_ids',prop_text_unit_ids)"
+                ),
+                bindings={
+                    "prop_id": row.id,
+                    "prop_name": row.name,
+                    "prop_type": row.type,
+                    "prop_description": row.description,
+                    "prop_human_readable_id": row.human_readable_id,
+                    "prop_partition_key": "entities",
+                    "prop_description_embedding":json.dumps(row.description_embedding.tolist() if row.description_embedding is not None else []),
+                    "prop_graph_embedding":json.dumps(row.graph_embedding.tolist() if row.graph_embedding is not None else []),
+                    "prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []),
+                },
+            )
+            time.sleep(5)  # crude per-row throttle between writes
+
+
+    def write_edges(self,data: pd.DataFrame)->None:
+        for row in data.itertuples():
+            print(row.source,row.target)
+            self._client.submit(
+                message=(
+                    "g.V().has('name',prop_source_id)"
+                    ".addE('connects')"
+                    ".to(g.V().has('name',prop_target_id))"
+                    ".property('weight',prop_weight)"
+                    ".property(list,'text_unit_ids',prop_text_unit_ids)"
+                    ".property('description',prop_description)"
+                    ".property('id',prop_id)"
+                    ".property('human_readable_id',prop_human_readable_id)"
+                    ".property('source_degree',prop_source_degree)"
+                    ".property('target_degree',prop_target_degree)"
+                    ".property('rank',prop_rank)"
+                    ".property('source',prop_source)"
+                    ".property('target',prop_target)"
+                ),
+                bindings={
+                    "prop_partition_key": "entities",
+                    "prop_source_id": row.source,
+                    "prop_target_id": row.target,
+                    "prop_weight": row.weight,
+                    "prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []),
+                    "prop_description": row.description,
+                    "prop_id": row.id,
+                    "prop_human_readable_id": row.human_readable_id,
+                    "prop_source_degree": row.source_degree,
+                    "prop_target_degree": row.target_degree,
+                    "prop_rank": row.rank,
+                    "prop_source": row.source,
+                    "prop_target": row.target,
+                },
+            )
+            time.sleep(5)  # crude per-row throttle between writes
\ No newline at end of file
diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py
index 4472ef8828..439f64cd75 100644
--- a/graphrag/index/create_pipeline_config.py
+++ b/graphrag/index/create_pipeline_config.py
@@ -370,6 +370,7 @@ def _graph_workflows(
                 },
             ),
             "skip_description_embedding": skip_relationship_description_embedding,
+            "emitter_type": TableEmitterType.Graphdb,
         },
     ),
     PipelineWorkflowReference(
diff --git a/graphrag/index/emit/graph_db_emitter.py b/graphrag/index/emit/graph_db_emitter.py
new file mode 100644
index 0000000000..aa8ed7cfb7
--- /dev/null
+++ b/graphrag/index/emit/graph_db_emitter.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""GraphDBEmitter module."""
+
+import pandas as pd
+
+from .table_emitter import TableEmitter
+
+from common.graph_db_client import GraphDBClient
+
+class GraphDBEmitter(TableEmitter):
+    """Emit the final entity and relationship tables to the graph database."""
+
+    def __init__(self):
+        self.graph_db_client = GraphDBClient()
+        self.allowed_workflows = ['create_final_entities','create_final_relationships']
+
+    async def emit(self, name: str, data: pd.DataFrame) -> None:
+        if name not in self.allowed_workflows:
+            return
+        if name == 'create_final_entities':
+            self.graph_db_client.write_vertices(data)
+        if name == 'create_final_relationships':
+            self.graph_db_client.write_edges(data)
\ No newline at end of file
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index 2128aff077..c1574ba06d 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -38,6 +38,8 @@
     read_indexer_text_units,
 )
 
+from common.graph_db_client import GraphDBClient
+
 reporter = PrintProgressReporter("")
 
 def __get_embedding_description_store(
@@ -100,14 +102,13 @@ def run_global_search(
     data_dir, root_dir, config = _configure_paths_and_settings(
         data_dir, root_dir, config_dir
     )
+    graph_db_client = GraphDBClient()
     data_path = Path(data_dir)
 
     final_nodes: pd.DataFrame = pd.read_parquet(
         data_path / "create_final_nodes.parquet"
     )
-    final_entities: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_entities.parquet"
-    )
+    final_entities = graph_db_client.query_vertices()
     final_community_reports: pd.DataFrame = pd.read_parquet(
         data_path / "create_final_community_reports.parquet"
    )
@@ -154,6 +155,7 @@ def run_local_search(
     final_relationships = pd.DataFrame()
     final_entities = pd.DataFrame()
     final_covariates = pd.DataFrame()
+    graph_db_client = GraphDBClient()
 
     for data_path in data_paths:
         #check from the config for the output storage type and then read the data from the storage.
@@ -164,10 +166,12 @@ def run_local_search( final_text_units = pd.concat([final_text_units, read_paraquet_file(config, data_path + "/create_final_text_units.parquet", config.storage.type)]) - final_relationships = pd.concat([final_relationships, read_paraquet_file(config, data_path + "/create_final_relationships.parquet", config.storage.type)]) + #final_relationships = pd.concat([final_relationships, read_paraquet_file(config, data_path + "/create_final_relationships.parquet", config.storage.type)]) + final_relationships = pd.concat([final_relationships, graph_db_client.query_edges()]) - final_entities = pd.concat([final_entities, read_paraquet_file(config, data_path + "/create_final_entities.parquet", config.storage.type)]) + #final_entities = pd.concat([final_entities, read_paraquet_file(config, data_path + "/create_final_entities.parquet", config.storage.type)]) + final_entities = pd.concat([final_entities, graph_db_client.query_vertices()]) final_covariates = pd.concat([final_covariates, ( read_paraquet_file(config, data_path + "/create_final_covariates.parquet", config.storage.type) From 3d2be1cf18a55c6089ab28bf7dcc10e14a8b1ff3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Wed, 21 Aug 2024 16:44:56 -0700 Subject: [PATCH 21/87] Integrating with latest PR --- graphrag/query/cli.py | 5 +- poetry.lock | 1296 ++++++++++++++++++++++++----------------- 2 files changed, 772 insertions(+), 529 deletions(-) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index c1574ba06d..d6d8cc7a88 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -173,8 +173,11 @@ def run_local_search( #final_entities = pd.concat([final_entities, read_paraquet_file(config, data_path + "/create_final_entities.parquet", config.storage.type)]) final_entities = pd.concat([final_entities, graph_db_client.query_vertices()]) + data_path_object = Path(data_path) + final_covariates_path = data_path_object / "create_final_covariates.parquet" + final_covariates = pd.concat([final_covariates, ( - read_paraquet_file(config, data_path + "/create_final_covariates.parquet", config.storage.type) + read_paraquet_file(config, final_covariates_path, config.storage.type) if final_covariates_path.exists() else None )]) vector_store_args = ( diff --git a/poetry.lock b/poetry.lock index 49b84c97e0..fbacd85265 100644 --- a/poetry.lock +++ b/poetry.lock @@ -25,98 +25,113 @@ files = [ [[package]] name = "aiohappyeyeballs" -version = "2.3.7" +version = "2.4.0" description = "Happy Eyeballs for asyncio" optional = false python-versions = ">=3.8" files = [ - {file = "aiohappyeyeballs-2.3.7-py3-none-any.whl", hash = "sha256:337ce4dc0e99eb697c3c5a77d6cb3c52925824d9a67ac0dea7c55b8a2d60b222"}, - {file = "aiohappyeyeballs-2.3.7.tar.gz", hash = "sha256:e794cd29ba6a14078092984e43688212a19081de3a73b6796c2fdeb3706dd6ce"}, + {file = "aiohappyeyeballs-2.4.0-py3-none-any.whl", hash = "sha256:7ce92076e249169a13c2f49320d1967425eaf1f407522d707d59cac7628d62bd"}, + {file = "aiohappyeyeballs-2.4.0.tar.gz", hash = "sha256:55a1714f084e63d49639800f95716da97a1f173d46a16dfcfda0016abb93b6b2"}, ] [[package]] name = "aiohttp" -version = "3.10.3" +version = "3.10.5" description = "Async http client/server framework (asyncio)" optional = false python-versions = ">=3.8" files = [ - {file = "aiohttp-3.10.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc36cbdedf6f259371dbbbcaae5bb0e95b879bc501668ab6306af867577eb5db"}, - {file = "aiohttp-3.10.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = 
"sha256:85466b5a695c2a7db13eb2c200af552d13e6a9313d7fa92e4ffe04a2c0ea74c1"}, - {file = "aiohttp-3.10.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:71bb1d97bfe7e6726267cea169fdf5df7658831bb68ec02c9c6b9f3511e108bb"}, - {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baec1eb274f78b2de54471fc4c69ecbea4275965eab4b556ef7a7698dee18bf2"}, - {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13031e7ec1188274bad243255c328cc3019e36a5a907978501256000d57a7201"}, - {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2bbc55a964b8eecb341e492ae91c3bd0848324d313e1e71a27e3d96e6ee7e8e8"}, - {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8cc0564b286b625e673a2615ede60a1704d0cbbf1b24604e28c31ed37dc62aa"}, - {file = "aiohttp-3.10.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f817a54059a4cfbc385a7f51696359c642088710e731e8df80d0607193ed2b73"}, - {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8542c9e5bcb2bd3115acdf5adc41cda394e7360916197805e7e32b93d821ef93"}, - {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:671efce3a4a0281060edf9a07a2f7e6230dca3a1cbc61d110eee7753d28405f7"}, - {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:0974f3b5b0132edcec92c3306f858ad4356a63d26b18021d859c9927616ebf27"}, - {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:44bb159b55926b57812dca1b21c34528e800963ffe130d08b049b2d6b994ada7"}, - {file = "aiohttp-3.10.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:6ae9ae382d1c9617a91647575255ad55a48bfdde34cc2185dd558ce476bf16e9"}, - {file = "aiohttp-3.10.3-cp310-cp310-win32.whl", hash = "sha256:aed12a54d4e1ee647376fa541e1b7621505001f9f939debf51397b9329fd88b9"}, - {file = "aiohttp-3.10.3-cp310-cp310-win_amd64.whl", hash = "sha256:b51aef59370baf7444de1572f7830f59ddbabd04e5292fa4218d02f085f8d299"}, - {file = "aiohttp-3.10.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e021c4c778644e8cdc09487d65564265e6b149896a17d7c0f52e9a088cc44e1b"}, - {file = "aiohttp-3.10.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:24fade6dae446b183e2410a8628b80df9b7a42205c6bfc2eff783cbeedc224a2"}, - {file = "aiohttp-3.10.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bc8e9f15939dacb0e1f2d15f9c41b786051c10472c7a926f5771e99b49a5957f"}, - {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5a9ec959b5381271c8ec9310aae1713b2aec29efa32e232e5ef7dcca0df0279"}, - {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a5d0ea8a6467b15d53b00c4e8ea8811e47c3cc1bdbc62b1aceb3076403d551f"}, - {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c9ed607dbbdd0d4d39b597e5bf6b0d40d844dfb0ac6a123ed79042ef08c1f87e"}, - {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3e66d5b506832e56add66af88c288c1d5ba0c38b535a1a59e436b300b57b23e"}, - {file = "aiohttp-3.10.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fda91ad797e4914cca0afa8b6cccd5d2b3569ccc88731be202f6adce39503189"}, - {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = 
"sha256:61ccb867b2f2f53df6598eb2a93329b5eee0b00646ee79ea67d68844747a418e"}, - {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:6d881353264e6156f215b3cb778c9ac3184f5465c2ece5e6fce82e68946868ef"}, - {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:b031ce229114825f49cec4434fa844ccb5225e266c3e146cb4bdd025a6da52f1"}, - {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5337cc742a03f9e3213b097abff8781f79de7190bbfaa987bd2b7ceb5bb0bdec"}, - {file = "aiohttp-3.10.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ab3361159fd3dcd0e48bbe804006d5cfb074b382666e6c064112056eb234f1a9"}, - {file = "aiohttp-3.10.3-cp311-cp311-win32.whl", hash = "sha256:05d66203a530209cbe40f102ebaac0b2214aba2a33c075d0bf825987c36f1f0b"}, - {file = "aiohttp-3.10.3-cp311-cp311-win_amd64.whl", hash = "sha256:70b4a4984a70a2322b70e088d654528129783ac1ebbf7dd76627b3bd22db2f17"}, - {file = "aiohttp-3.10.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:166de65e2e4e63357cfa8417cf952a519ac42f1654cb2d43ed76899e2319b1ee"}, - {file = "aiohttp-3.10.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7084876352ba3833d5d214e02b32d794e3fd9cf21fdba99cff5acabeb90d9806"}, - {file = "aiohttp-3.10.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d98c604c93403288591d7d6d7d6cc8a63459168f8846aeffd5b3a7f3b3e5e09"}, - {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d73b073a25a0bb8bf014345374fe2d0f63681ab5da4c22f9d2025ca3e3ea54fc"}, - {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8da6b48c20ce78f5721068f383e0e113dde034e868f1b2f5ee7cb1e95f91db57"}, - {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3a9dcdccf50284b1b0dc72bc57e5bbd3cc9bf019060dfa0668f63241ccc16aa7"}, - {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56fb94bae2be58f68d000d046172d8b8e6b1b571eb02ceee5535e9633dcd559c"}, - {file = "aiohttp-3.10.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bf75716377aad2c718cdf66451c5cf02042085d84522aec1f9246d3e4b8641a6"}, - {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6c51ed03e19c885c8e91f574e4bbe7381793f56f93229731597e4a499ffef2a5"}, - {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b84857b66fa6510a163bb083c1199d1ee091a40163cfcbbd0642495fed096204"}, - {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c124b9206b1befe0491f48185fd30a0dd51b0f4e0e7e43ac1236066215aff272"}, - {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3461d9294941937f07bbbaa6227ba799bc71cc3b22c40222568dc1cca5118f68"}, - {file = "aiohttp-3.10.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:08bd0754d257b2db27d6bab208c74601df6f21bfe4cb2ec7b258ba691aac64b3"}, - {file = "aiohttp-3.10.3-cp312-cp312-win32.whl", hash = "sha256:7f9159ae530297f61a00116771e57516f89a3de6ba33f314402e41560872b50a"}, - {file = "aiohttp-3.10.3-cp312-cp312-win_amd64.whl", hash = "sha256:e1128c5d3a466279cb23c4aa32a0f6cb0e7d2961e74e9e421f90e74f75ec1edf"}, - {file = "aiohttp-3.10.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:d1100e68e70eb72eadba2b932b185ebf0f28fd2f0dbfe576cfa9d9894ef49752"}, - {file = "aiohttp-3.10.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = 
"sha256:a541414578ff47c0a9b0b8b77381ea86b0c8531ab37fc587572cb662ccd80b88"}, - {file = "aiohttp-3.10.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d5548444ef60bf4c7b19ace21f032fa42d822e516a6940d36579f7bfa8513f9c"}, - {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ba2e838b5e6a8755ac8297275c9460e729dc1522b6454aee1766c6de6d56e5e"}, - {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:48665433bb59144aaf502c324694bec25867eb6630fcd831f7a893ca473fcde4"}, - {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bac352fceed158620ce2d701ad39d4c1c76d114255a7c530e057e2b9f55bdf9f"}, - {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b0f670502100cdc567188c49415bebba947eb3edaa2028e1a50dd81bd13363f"}, - {file = "aiohttp-3.10.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43b09f38a67679e32d380fe512189ccb0b25e15afc79b23fbd5b5e48e4fc8fd9"}, - {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:cd788602e239ace64f257d1c9d39898ca65525583f0fbf0988bcba19418fe93f"}, - {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:214277dcb07ab3875f17ee1c777d446dcce75bea85846849cc9d139ab8f5081f"}, - {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:32007fdcaab789689c2ecaaf4b71f8e37bf012a15cd02c0a9db8c4d0e7989fa8"}, - {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:123e5819bfe1b87204575515cf448ab3bf1489cdeb3b61012bde716cda5853e7"}, - {file = "aiohttp-3.10.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:812121a201f0c02491a5db335a737b4113151926a79ae9ed1a9f41ea225c0e3f"}, - {file = "aiohttp-3.10.3-cp38-cp38-win32.whl", hash = "sha256:b97dc9a17a59f350c0caa453a3cb35671a2ffa3a29a6ef3568b523b9113d84e5"}, - {file = "aiohttp-3.10.3-cp38-cp38-win_amd64.whl", hash = "sha256:3731a73ddc26969d65f90471c635abd4e1546a25299b687e654ea6d2fc052394"}, - {file = "aiohttp-3.10.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:38d91b98b4320ffe66efa56cb0f614a05af53b675ce1b8607cdb2ac826a8d58e"}, - {file = "aiohttp-3.10.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9743fa34a10a36ddd448bba8a3adc2a66a1c575c3c2940301bacd6cc896c6bf1"}, - {file = "aiohttp-3.10.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7c126f532caf238031c19d169cfae3c6a59129452c990a6e84d6e7b198a001dc"}, - {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:926e68438f05703e500b06fe7148ef3013dd6f276de65c68558fa9974eeb59ad"}, - {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:434b3ab75833accd0b931d11874e206e816f6e6626fd69f643d6a8269cd9166a"}, - {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d35235a44ec38109b811c3600d15d8383297a8fab8e3dec6147477ec8636712a"}, - {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59c489661edbd863edb30a8bd69ecb044bd381d1818022bc698ba1b6f80e5dd1"}, - {file = "aiohttp-3.10.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50544fe498c81cb98912afabfc4e4d9d85e89f86238348e3712f7ca6a2f01dab"}, - {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:09bc79275737d4dc066e0ae2951866bb36d9c6b460cb7564f111cc0427f14844"}, - {file = 
"aiohttp-3.10.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:af4dbec58e37f5afff4f91cdf235e8e4b0bd0127a2a4fd1040e2cad3369d2f06"}, - {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:b22cae3c9dd55a6b4c48c63081d31c00fc11fa9db1a20c8a50ee38c1a29539d2"}, - {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ba562736d3fbfe9241dad46c1a8994478d4a0e50796d80e29d50cabe8fbfcc3f"}, - {file = "aiohttp-3.10.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:f25d6c4e82d7489be84f2b1c8212fafc021b3731abdb61a563c90e37cced3a21"}, - {file = "aiohttp-3.10.3-cp39-cp39-win32.whl", hash = "sha256:b69d832e5f5fa15b1b6b2c8eb6a9fd2c0ec1fd7729cb4322ed27771afc9fc2ac"}, - {file = "aiohttp-3.10.3-cp39-cp39-win_amd64.whl", hash = "sha256:673bb6e3249dc8825df1105f6ef74e2eab779b7ff78e96c15cadb78b04a83752"}, - {file = "aiohttp-3.10.3.tar.gz", hash = "sha256:21650e7032cc2d31fc23d353d7123e771354f2a3d5b05a5647fc30fea214e696"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:18a01eba2574fb9edd5f6e5fb25f66e6ce061da5dab5db75e13fe1558142e0a3"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:94fac7c6e77ccb1ca91e9eb4cb0ac0270b9fb9b289738654120ba8cebb1189c6"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2f1f1c75c395991ce9c94d3e4aa96e5c59c8356a15b1c9231e783865e2772699"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7acae3cf1a2a2361ec4c8e787eaaa86a94171d2417aae53c0cca6ca3118ff6"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:94c4381ffba9cc508b37d2e536b418d5ea9cfdc2848b9a7fea6aebad4ec6aac1"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c31ad0c0c507894e3eaa843415841995bf8de4d6b2d24c6e33099f4bc9fc0d4f"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0912b8a8fadeb32ff67a3ed44249448c20148397c1ed905d5dac185b4ca547bb"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d93400c18596b7dc4794d48a63fb361b01a0d8eb39f28800dc900c8fbdaca91"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d00f3c5e0d764a5c9aa5a62d99728c56d455310bcc288a79cab10157b3af426f"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:d742c36ed44f2798c8d3f4bc511f479b9ceef2b93f348671184139e7d708042c"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:814375093edae5f1cb31e3407997cf3eacefb9010f96df10d64829362ae2df69"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8224f98be68a84b19f48e0bdc14224b5a71339aff3a27df69989fa47d01296f3"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d9a487ef090aea982d748b1b0d74fe7c3950b109df967630a20584f9a99c0683"}, + {file = "aiohttp-3.10.5-cp310-cp310-win32.whl", hash = "sha256:d9ef084e3dc690ad50137cc05831c52b6ca428096e6deb3c43e95827f531d5ef"}, + {file = "aiohttp-3.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:66bf9234e08fe561dccd62083bf67400bdbf1c67ba9efdc3dac03650e97c6088"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8c6a4e5e40156d72a40241a25cc226051c0a8d816610097a8e8f517aeacd59a2"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:2c634a3207a5445be65536d38c13791904fda0748b9eabf908d3fe86a52941cf"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4aff049b5e629ef9b3e9e617fa6e2dfeda1bf87e01bcfecaf3949af9e210105e"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1942244f00baaacaa8155eca94dbd9e8cc7017deb69b75ef67c78e89fdad3c77"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e04a1f2a65ad2f93aa20f9ff9f1b672bf912413e5547f60749fa2ef8a644e061"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f2bfc0032a00405d4af2ba27f3c429e851d04fad1e5ceee4080a1c570476697"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:424ae21498790e12eb759040bbb504e5e280cab64693d14775c54269fd1d2bb7"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:975218eee0e6d24eb336d0328c768ebc5d617609affaca5dbbd6dd1984f16ed0"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4120d7fefa1e2d8fb6f650b11489710091788de554e2b6f8347c7a20ceb003f5"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b90078989ef3fc45cf9221d3859acd1108af7560c52397ff4ace8ad7052a132e"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ba5a8b74c2a8af7d862399cdedce1533642fa727def0b8c3e3e02fcb52dca1b1"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:02594361128f780eecc2a29939d9dfc870e17b45178a867bf61a11b2a4367277"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8fb4fc029e135859f533025bc82047334e24b0d489e75513144f25408ecaf058"}, + {file = "aiohttp-3.10.5-cp311-cp311-win32.whl", hash = "sha256:e1ca1ef5ba129718a8fc827b0867f6aa4e893c56eb00003b7367f8a733a9b072"}, + {file = "aiohttp-3.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:349ef8a73a7c5665cca65c88ab24abe75447e28aa3bc4c93ea5093474dfdf0ff"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:305be5ff2081fa1d283a76113b8df7a14c10d75602a38d9f012935df20731487"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3a1c32a19ee6bbde02f1cb189e13a71b321256cc1d431196a9f824050b160d5a"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:61645818edd40cc6f455b851277a21bf420ce347baa0b86eaa41d51ef58ba23d"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c225286f2b13bab5987425558baa5cbdb2bc925b2998038fa028245ef421e75"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ba01ebc6175e1e6b7275c907a3a36be48a2d487549b656aa90c8a910d9f3178"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8eaf44ccbc4e35762683078b72bf293f476561d8b68ec8a64f98cf32811c323e"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c43eb1ab7cbf411b8e387dc169acb31f0ca0d8c09ba63f9eac67829585b44f"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de7a5299827253023c55ea549444e058c0eb496931fa05d693b95140a947cb73"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:4790f0e15f00058f7599dab2b206d3049d7ac464dc2e5eae0e93fa18aee9e7bf"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:44b324a6b8376a23e6ba25d368726ee3bc281e6ab306db80b5819999c737d820"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0d277cfb304118079e7044aad0b76685d30ecb86f83a0711fc5fb257ffe832ca"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:54d9ddea424cd19d3ff6128601a4a4d23d54a421f9b4c0fff740505813739a91"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4f1c9866ccf48a6df2b06823e6ae80573529f2af3a0992ec4fe75b1a510df8a6"}, + {file = "aiohttp-3.10.5-cp312-cp312-win32.whl", hash = "sha256:dc4826823121783dccc0871e3f405417ac116055bf184ac04c36f98b75aacd12"}, + {file = "aiohttp-3.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:22c0a23a3b3138a6bf76fc553789cb1a703836da86b0f306b6f0dc1617398abc"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7f6b639c36734eaa80a6c152a238242bedcee9b953f23bb887e9102976343092"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f29930bc2921cef955ba39a3ff87d2c4398a0394ae217f41cb02d5c26c8b1b77"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f489a2c9e6455d87eabf907ac0b7d230a9786be43fbe884ad184ddf9e9c1e385"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:123dd5b16b75b2962d0fff566effb7a065e33cd4538c1692fb31c3bda2bfb972"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b98e698dc34966e5976e10bbca6d26d6724e6bdea853c7c10162a3235aba6e16"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3b9162bab7e42f21243effc822652dc5bb5e8ff42a4eb62fe7782bcbcdfacf6"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1923a5c44061bffd5eebeef58cecf68096e35003907d8201a4d0d6f6e387ccaa"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d55f011da0a843c3d3df2c2cf4e537b8070a419f891c930245f05d329c4b0689"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:afe16a84498441d05e9189a15900640a2d2b5e76cf4efe8cbb088ab4f112ee57"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8112fb501b1e0567a1251a2fd0747baae60a4ab325a871e975b7bb67e59221f"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1e72589da4c90337837fdfe2026ae1952c0f4a6e793adbbfbdd40efed7c63599"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4d46c7b4173415d8e583045fbc4daa48b40e31b19ce595b8d92cf639396c15d5"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:33e6bc4bab477c772a541f76cd91e11ccb6d2efa2b8d7d7883591dfb523e5987"}, + {file = "aiohttp-3.10.5-cp313-cp313-win32.whl", hash = "sha256:c58c6837a2c2a7cf3133983e64173aec11f9c2cd8e87ec2fdc16ce727bcf1a04"}, + {file = "aiohttp-3.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:38172a70005252b6893088c0f5e8a47d173df7cc2b2bd88650957eb84fcf5022"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f6f18898ace4bcd2d41a122916475344a87f1dfdec626ecde9ee802a711bc569"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = 
"sha256:5ede29d91a40ba22ac1b922ef510aab871652f6c88ef60b9dcdf773c6d32ad7a"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:673f988370f5954df96cc31fd99c7312a3af0a97f09e407399f61583f30da9bc"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58718e181c56a3c02d25b09d4115eb02aafe1a732ce5714ab70326d9776457c3"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b38b1570242fbab8d86a84128fb5b5234a2f70c2e32f3070143a6d94bc854cf"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:074d1bff0163e107e97bd48cad9f928fa5a3eb4b9d33366137ffce08a63e37fe"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd31f176429cecbc1ba499d4aba31aaccfea488f418d60376b911269d3b883c5"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7384d0b87d4635ec38db9263e6a3f1eb609e2e06087f0aa7f63b76833737b471"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8989f46f3d7ef79585e98fa991e6ded55d2f48ae56d2c9fa5e491a6e4effb589"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:c83f7a107abb89a227d6c454c613e7606c12a42b9a4ca9c5d7dad25d47c776ae"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:cde98f323d6bf161041e7627a5fd763f9fd829bcfcd089804a5fdce7bb6e1b7d"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:676f94c5480d8eefd97c0c7e3953315e4d8c2b71f3b49539beb2aa676c58272f"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2d21ac12dc943c68135ff858c3a989f2194a709e6e10b4c8977d7fcd67dfd511"}, + {file = "aiohttp-3.10.5-cp38-cp38-win32.whl", hash = "sha256:17e997105bd1a260850272bfb50e2a328e029c941c2708170d9d978d5a30ad9a"}, + {file = "aiohttp-3.10.5-cp38-cp38-win_amd64.whl", hash = "sha256:1c19de68896747a2aa6257ae4cf6ef59d73917a36a35ee9d0a6f48cff0f94db8"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7e2fe37ac654032db1f3499fe56e77190282534810e2a8e833141a021faaab0e"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5bf3ead3cb66ab990ee2561373b009db5bc0e857549b6c9ba84b20bc462e172"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1b2c16a919d936ca87a3c5f0e43af12a89a3ce7ccbce59a2d6784caba945b68b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad146dae5977c4dd435eb31373b3fe9b0b1bf26858c6fc452bf6af394067e10b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c5c6fa16412b35999320f5c9690c0f554392dc222c04e559217e0f9ae244b92"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:95c4dc6f61d610bc0ee1edc6f29d993f10febfe5b76bb470b486d90bbece6b22"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da452c2c322e9ce0cfef392e469a26d63d42860f829026a63374fde6b5c5876f"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:898715cf566ec2869d5cb4d5fb4be408964704c46c96b4be267442d265390f32"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:391cc3a9c1527e424c6865e087897e766a917f15dddb360174a70467572ac6ce"}, + {file = 
"aiohttp-3.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:380f926b51b92d02a34119d072f178d80bbda334d1a7e10fa22d467a66e494db"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce91db90dbf37bb6fa0997f26574107e1b9d5ff939315247b7e615baa8ec313b"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9093a81e18c45227eebe4c16124ebf3e0d893830c6aca7cc310bfca8fe59d857"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ee40b40aa753d844162dcc80d0fe256b87cba48ca0054f64e68000453caead11"}, + {file = "aiohttp-3.10.5-cp39-cp39-win32.whl", hash = "sha256:03f2645adbe17f274444953bdea69f8327e9d278d961d85657cb0d06864814c1"}, + {file = "aiohttp-3.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:d17920f18e6ee090bdd3d0bfffd769d9f2cb4c8ffde3eb203777a3895c128862"}, + {file = "aiohttp-3.10.5.tar.gz", hash = "sha256:f071854b47d39591ce9a17981c46790acb30518e2f83dfca8db2dfa091178691"}, ] [package.dependencies] @@ -335,13 +350,13 @@ files = [ [[package]] name = "attrs" -version = "24.1.0" +version = "24.2.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" files = [ - {file = "attrs-24.1.0-py3-none-any.whl", hash = "sha256:377b47448cb61fea38533f671fba0d0f8a96fd58facd4dc518e3dac9dbea0905"}, - {file = "attrs-24.1.0.tar.gz", hash = "sha256:adbdec84af72d38be7628e353a09b6a6790d15cd71819f6e9d7b0faa8a125745"}, + {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, + {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, ] [package.extras] @@ -415,6 +430,29 @@ msal = ">=1.24.0" msal-extensions = ">=0.3.0" typing-extensions = ">=4.0.0" +[[package]] +name = "azure-kusto-data" +version = "4.5.1" +description = "Kusto Data Client" +optional = false +python-versions = "*" +files = [ + {file = "azure-kusto-data-4.5.1.tar.gz", hash = "sha256:13675546254021bba4fccba801c15ca8f18037510f527fe518001b544e1335ce"}, + {file = "azure_kusto_data-4.5.1-py2.py3-none-any.whl", hash = "sha256:8ae99bd2a88793345d2cbe32f2c14ce27a2d1e3adcee29a718b7e7c3ccb1fcbd"}, +] + +[package.dependencies] +azure-core = ">=1.11.0,<2" +azure-identity = ">=1.5.0,<2" +ijson = ">=3.1,<4.0" +msal = ">=1.9.0,<2" +python-dateutil = ">=2.8.0" +requests = ">=2.13.0" + +[package.extras] +aio = ["aiohttp (>=3.8.0,<4)", "asgiref (>=3.2.3,<4)"] +pandas = ["pandas"] + [[package]] name = "azure-search-documents" version = "11.5.1" @@ -434,13 +472,13 @@ typing-extensions = ">=4.6.0" [[package]] name = "azure-storage-blob" -version = "12.21.0" +version = "12.22.0" description = "Microsoft Azure Blob Storage Client Library for Python" optional = false python-versions = ">=3.8" files = [ - {file = "azure-storage-blob-12.21.0.tar.gz", hash = "sha256:b9722725072f5b7373c0f4dd6d78fbae2bb37bffc5c3e01731ab8c750ee8dd7e"}, - {file = "azure_storage_blob-12.21.0-py3-none-any.whl", hash = "sha256:f9ede187dd5a0ef296b583a7c1861c6938ddd6708d6e70f4203a163c2ab42d43"}, + {file = "azure-storage-blob-12.22.0.tar.gz", hash = "sha256:b3804bb4fe8ab1c32771fa464053da772a682c2737b19da438a3f4e5e3b3736e"}, + {file = "azure_storage_blob-12.22.0-py3-none-any.whl", hash = "sha256:bb7d2d824ce3f11f14a27ee7d9281289f7e072ac8311c52e3652672455b7d5e8"}, ] [package.dependencies] @@ -454,13 +492,13 @@ aio = ["azure-core[aio] (>=1.28.0)"] [[package]] name = "babel" -version = "2.15.0" +version = "2.16.0" description = "Internationalization utilities" optional = false 
python-versions = ">=3.8" files = [ - {file = "Babel-2.15.0-py3-none-any.whl", hash = "sha256:08706bdad8d0a3413266ab61bd6c34d0c28d6e1e7badf40a2cebe67644e2e1fb"}, - {file = "babel-2.15.0.tar.gz", hash = "sha256:8daf0e265d05768bc6c7a314cf1321e9a123afc328cc635c18622a2f30a04413"}, + {file = "babel-2.16.0-py3-none-any.whl", hash = "sha256:368b5b98b37c06b7daf6696391c3240c938b37767d4584413e8438c5c435fa8b"}, + {file = "babel-2.16.0.tar.gz", hash = "sha256:d1f3554ca26605fe173f3de0c65f750f5a42f924499bf134de6423582298e316"}, ] [package.extras] @@ -525,13 +563,13 @@ css = ["tinycss2 (>=1.1.0,<1.3)"] [[package]] name = "cachetools" -version = "5.4.0" +version = "5.5.0" description = "Extensible memoizing collections and decorators" optional = false python-versions = ">=3.7" files = [ - {file = "cachetools-5.4.0-py3-none-any.whl", hash = "sha256:3ae3b49a3d5e28a77a0be2b37dbcb89005058959cb2323858c2657c4a8cab474"}, - {file = "cachetools-5.4.0.tar.gz", hash = "sha256:b8adc2e7c07f105ced7bc56dbb6dfbe7c4a00acce20e2227b3f355be89bc6827"}, + {file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"}, + {file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"}, ] [[package]] @@ -547,63 +585,78 @@ files = [ [[package]] name = "cffi" -version = "1.16.0" +version = "1.17.0" description = "Foreign Function Interface for Python calling C code." optional = false python-versions = ">=3.8" files = [ - {file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"}, - {file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e61e3e4fa664a8588aa25c883eab612a188c725755afff6289454d6362b9673"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a72e8961a86d19bdb45851d8f1f08b041ea37d2bd8d4fd19903bc3083d80c896"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b50bf3f55561dac5438f8e70bfcdfd74543fd60df5fa5f62d94e5867deca684"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7651c50c8c5ef7bdb41108b7b8c5a83013bfaa8a935590c5d74627c047a583c7"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4108df7fe9b707191e55f33efbcb2d81928e10cea45527879a4749cbe472614"}, - {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:32c68ef735dbe5857c810328cb2481e24722a59a2003018885514d4c09af9743"}, - {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:673739cb539f8cdaa07d92d02efa93c9ccf87e345b9a0b556e3ecc666718468d"}, - {file = "cffi-1.16.0-cp310-cp310-win32.whl", hash = "sha256:9f90389693731ff1f659e55c7d1640e2ec43ff725cc61b04b2f9c6d8d017df6a"}, - {file = "cffi-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:e6024675e67af929088fda399b2094574609396b1decb609c55fa58b028a32a1"}, - {file = "cffi-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b84834d0cf97e7d27dd5b7f3aca7b6e9263c56308ab9dc8aae9784abb774d404"}, - {file = "cffi-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b8ebc27c014c59692bb2664c7d13ce7a6e9a629be20e54e7271fa696ff2b417"}, - {file = 
"cffi-1.16.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ee07e47c12890ef248766a6e55bd38ebfb2bb8edd4142d56db91b21ea68b7627"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8a9d3ebe49f084ad71f9269834ceccbf398253c9fac910c4fd7053ff1386936"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e70f54f1796669ef691ca07d046cd81a29cb4deb1e5f942003f401c0c4a2695d"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5bf44d66cdf9e893637896c7faa22298baebcd18d1ddb6d2626a6e39793a1d56"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b78010e7b97fef4bee1e896df8a4bbb6712b7f05b7ef630f9d1da00f6444d2e"}, - {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c6a164aa47843fb1b01e941d385aab7215563bb8816d80ff3a363a9f8448a8dc"}, - {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e09f3ff613345df5e8c3667da1d918f9149bd623cd9070c983c013792a9a62eb"}, - {file = "cffi-1.16.0-cp311-cp311-win32.whl", hash = "sha256:2c56b361916f390cd758a57f2e16233eb4f64bcbeee88a4881ea90fca14dc6ab"}, - {file = "cffi-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:db8e577c19c0fda0beb7e0d4e09e0ba74b1e4c092e0e40bfa12fe05b6f6d75ba"}, - {file = "cffi-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fa3a0128b152627161ce47201262d3140edb5a5c3da88d73a1b790a959126956"}, - {file = "cffi-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68e7c44931cc171c54ccb702482e9fc723192e88d25a0e133edd7aff8fcd1f6e"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abd808f9c129ba2beda4cfc53bde801e5bcf9d6e0f22f095e45327c038bfe68e"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88e2b3c14bdb32e440be531ade29d3c50a1a59cd4e51b1dd8b0865c54ea5d2e2"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcc8eb6d5902bb1cf6dc4f187ee3ea80a1eba0a89aba40a5cb20a5087d961357"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7be2d771cdba2942e13215c4e340bfd76398e9227ad10402a8767ab1865d2e6"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e715596e683d2ce000574bae5d07bd522c781a822866c20495e52520564f0969"}, - {file = "cffi-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2d92b25dbf6cae33f65005baf472d2c245c050b1ce709cc4588cdcdd5495b520"}, - {file = "cffi-1.16.0-cp312-cp312-win32.whl", hash = "sha256:b2ca4e77f9f47c55c194982e10f058db063937845bb2b7a86c84a6cfe0aefa8b"}, - {file = "cffi-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:68678abf380b42ce21a5f2abde8efee05c114c2fdb2e9eef2efdb0257fba1235"}, - {file = "cffi-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0c9ef6ff37e974b73c25eecc13952c55bceed9112be2d9d938ded8e856138bcc"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a09582f178759ee8128d9270cd1344154fd473bb77d94ce0aeb2a93ebf0feaf0"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e760191dd42581e023a68b758769e2da259b5d52e3103c6060ddc02c9edb8d7b"}, - {file = 
"cffi-1.16.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80876338e19c951fdfed6198e70bc88f1c9758b94578d5a7c4c91a87af3cf31c"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6a14b17d7e17fa0d207ac08642c8820f84f25ce17a442fd15e27ea18d67c59b"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6602bc8dc6f3a9e02b6c22c4fc1e47aa50f8f8e6d3f78a5e16ac33ef5fefa324"}, - {file = "cffi-1.16.0-cp38-cp38-win32.whl", hash = "sha256:131fd094d1065b19540c3d72594260f118b231090295d8c34e19a7bbcf2e860a"}, - {file = "cffi-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:31d13b0f99e0836b7ff893d37af07366ebc90b678b6664c955b54561fc36ef36"}, - {file = "cffi-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:582215a0e9adbe0e379761260553ba11c58943e4bbe9c36430c4ca6ac74b15ed"}, - {file = "cffi-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b29ebffcf550f9da55bec9e02ad430c992a87e5f512cd63388abb76f1036d8d2"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc9b18bf40cc75f66f40a7379f6a9513244fe33c0e8aa72e2d56b0196a7ef872"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cb4a35b3642fc5c005a6755a5d17c6c8b6bcb6981baf81cea8bfbc8903e8ba8"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b86851a328eedc692acf81fb05444bdf1891747c25af7529e39ddafaf68a4f3f"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0f31130ebc2d37cdd8e44605fb5fa7ad59049298b3f745c74fa74c62fbfcfc4"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f8e709127c6c77446a8c0a8c8bf3c8ee706a06cd44b1e827c3e6a2ee6b8c098"}, - {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:748dcd1e3d3d7cd5443ef03ce8685043294ad6bd7c02a38d1bd367cfd968e000"}, - {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8895613bcc094d4a1b2dbe179d88d7fb4a15cee43c052e8885783fac397d91fe"}, - {file = "cffi-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed86a35631f7bfbb28e108dd96773b9d5a6ce4811cf6ea468bb6a359b256b1e4"}, - {file = "cffi-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:3686dffb02459559c74dd3d81748269ffb0eb027c39a6fc99502de37d501faa8"}, - {file = "cffi-1.16.0.tar.gz", hash = "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0"}, + {file = "cffi-1.17.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f9338cc05451f1942d0d8203ec2c346c830f8e86469903d5126c1f0a13a2bcbb"}, + {file = "cffi-1.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a0ce71725cacc9ebf839630772b07eeec220cbb5f03be1399e0457a1464f8e1a"}, + {file = "cffi-1.17.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c815270206f983309915a6844fe994b2fa47e5d05c4c4cef267c3b30e34dbe42"}, + {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6bdcd415ba87846fd317bee0774e412e8792832e7805938987e4ede1d13046d"}, + {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a98748ed1a1df4ee1d6f927e151ed6c1a09d5ec21684de879c7ea6aa96f58f2"}, + {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0a048d4f6630113e54bb4b77e315e1ba32a5a31512c31a273807d0027a7e69ab"}, + {file = 
"cffi-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24aa705a5f5bd3a8bcfa4d123f03413de5d86e497435693b638cbffb7d5d8a1b"}, + {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:856bf0924d24e7f93b8aee12a3a1095c34085600aa805693fb7f5d1962393206"}, + {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:4304d4416ff032ed50ad6bb87416d802e67139e31c0bde4628f36a47a3164bfa"}, + {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:331ad15c39c9fe9186ceaf87203a9ecf5ae0ba2538c9e898e3a6967e8ad3db6f"}, + {file = "cffi-1.17.0-cp310-cp310-win32.whl", hash = "sha256:669b29a9eca6146465cc574659058ed949748f0809a2582d1f1a324eb91054dc"}, + {file = "cffi-1.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:48b389b1fd5144603d61d752afd7167dfd205973a43151ae5045b35793232aa2"}, + {file = "cffi-1.17.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c5d97162c196ce54af6700949ddf9409e9833ef1003b4741c2b39ef46f1d9720"}, + {file = "cffi-1.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5ba5c243f4004c750836f81606a9fcb7841f8874ad8f3bf204ff5e56332b72b9"}, + {file = "cffi-1.17.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bb9333f58fc3a2296fb1d54576138d4cf5d496a2cc118422bd77835e6ae0b9cb"}, + {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:435a22d00ec7d7ea533db494da8581b05977f9c37338c80bc86314bec2619424"}, + {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1df34588123fcc88c872f5acb6f74ae59e9d182a2707097f9e28275ec26a12d"}, + {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:df8bb0010fdd0a743b7542589223a2816bdde4d94bb5ad67884348fa2c1c67e8"}, + {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8b5b9712783415695663bd463990e2f00c6750562e6ad1d28e072a611c5f2a6"}, + {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ffef8fd58a36fb5f1196919638f73dd3ae0db1a878982b27a9a5a176ede4ba91"}, + {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e67d26532bfd8b7f7c05d5a766d6f437b362c1bf203a3a5ce3593a645e870b8"}, + {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45f7cd36186db767d803b1473b3c659d57a23b5fa491ad83c6d40f2af58e4dbb"}, + {file = "cffi-1.17.0-cp311-cp311-win32.whl", hash = "sha256:a9015f5b8af1bb6837a3fcb0cdf3b874fe3385ff6274e8b7925d81ccaec3c5c9"}, + {file = "cffi-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:b50aaac7d05c2c26dfd50c3321199f019ba76bb650e346a6ef3616306eed67b0"}, + {file = "cffi-1.17.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aec510255ce690d240f7cb23d7114f6b351c733a74c279a84def763660a2c3bc"}, + {file = "cffi-1.17.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2770bb0d5e3cc0e31e7318db06efcbcdb7b31bcb1a70086d3177692a02256f59"}, + {file = "cffi-1.17.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:db9a30ec064129d605d0f1aedc93e00894b9334ec74ba9c6bdd08147434b33eb"}, + {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a47eef975d2b8b721775a0fa286f50eab535b9d56c70a6e62842134cf7841195"}, + {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f3e0992f23bbb0be00a921eae5363329253c3b86287db27092461c887b791e5e"}, + {file = 
"cffi-1.17.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6107e445faf057c118d5050560695e46d272e5301feffda3c41849641222a828"}, + {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb862356ee9391dc5a0b3cbc00f416b48c1b9a52d252d898e5b7696a5f9fe150"}, + {file = "cffi-1.17.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c1c13185b90bbd3f8b5963cd8ce7ad4ff441924c31e23c975cb150e27c2bf67a"}, + {file = "cffi-1.17.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:17c6d6d3260c7f2d94f657e6872591fe8733872a86ed1345bda872cfc8c74885"}, + {file = "cffi-1.17.0-cp312-cp312-win32.whl", hash = "sha256:c3b8bd3133cd50f6b637bb4322822c94c5ce4bf0d724ed5ae70afce62187c492"}, + {file = "cffi-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:dca802c8db0720ce1c49cce1149ff7b06e91ba15fa84b1d59144fef1a1bc7ac2"}, + {file = "cffi-1.17.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6ce01337d23884b21c03869d2f68c5523d43174d4fc405490eb0091057943118"}, + {file = "cffi-1.17.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cab2eba3830bf4f6d91e2d6718e0e1c14a2f5ad1af68a89d24ace0c6b17cced7"}, + {file = "cffi-1.17.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:14b9cbc8f7ac98a739558eb86fabc283d4d564dafed50216e7f7ee62d0d25377"}, + {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b00e7bcd71caa0282cbe3c90966f738e2db91e64092a877c3ff7f19a1628fdcb"}, + {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:41f4915e09218744d8bae14759f983e466ab69b178de38066f7579892ff2a555"}, + {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4760a68cab57bfaa628938e9c2971137e05ce48e762a9cb53b76c9b569f1204"}, + {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:011aff3524d578a9412c8b3cfaa50f2c0bd78e03eb7af7aa5e0df59b158efb2f"}, + {file = "cffi-1.17.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:a003ac9edc22d99ae1286b0875c460351f4e101f8c9d9d2576e78d7e048f64e0"}, + {file = "cffi-1.17.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ef9528915df81b8f4c7612b19b8628214c65c9b7f74db2e34a646a0a2a0da2d4"}, + {file = "cffi-1.17.0-cp313-cp313-win32.whl", hash = "sha256:70d2aa9fb00cf52034feac4b913181a6e10356019b18ef89bc7c12a283bf5f5a"}, + {file = "cffi-1.17.0-cp313-cp313-win_amd64.whl", hash = "sha256:b7b6ea9e36d32582cda3465f54c4b454f62f23cb083ebc7a94e2ca6ef011c3a7"}, + {file = "cffi-1.17.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:964823b2fc77b55355999ade496c54dde161c621cb1f6eac61dc30ed1b63cd4c"}, + {file = "cffi-1.17.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:516a405f174fd3b88829eabfe4bb296ac602d6a0f68e0d64d5ac9456194a5b7e"}, + {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dec6b307ce928e8e112a6bb9921a1cb00a0e14979bf28b98e084a4b8a742bd9b"}, + {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e4094c7b464cf0a858e75cd14b03509e84789abf7b79f8537e6a72152109c76e"}, + {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2404f3de742f47cb62d023f0ba7c5a916c9c653d5b368cc966382ae4e57da401"}, + {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:3aa9d43b02a0c681f0bfbc12d476d47b2b2b6a3f9287f11ee42989a268a1833c"}, + {file = "cffi-1.17.0-cp38-cp38-win32.whl", hash = "sha256:0bb15e7acf8ab35ca8b24b90af52c8b391690ef5c4aec3d31f38f0d37d2cc499"}, + {file = "cffi-1.17.0-cp38-cp38-win_amd64.whl", hash = "sha256:93a7350f6706b31f457c1457d3a3259ff9071a66f312ae64dc024f049055f72c"}, + {file = "cffi-1.17.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1a2ddbac59dc3716bc79f27906c010406155031a1c801410f1bafff17ea304d2"}, + {file = "cffi-1.17.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6327b572f5770293fc062a7ec04160e89741e8552bf1c358d1a23eba68166759"}, + {file = "cffi-1.17.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbc183e7bef690c9abe5ea67b7b60fdbca81aa8da43468287dae7b5c046107d4"}, + {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bdc0f1f610d067c70aa3737ed06e2726fd9d6f7bfee4a351f4c40b6831f4e82"}, + {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6d872186c1617d143969defeadac5a904e6e374183e07977eedef9c07c8953bf"}, + {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d46ee4764b88b91f16661a8befc6bfb24806d885e27436fdc292ed7e6f6d058"}, + {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f76a90c345796c01d85e6332e81cab6d70de83b829cf1d9762d0a3da59c7932"}, + {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0e60821d312f99d3e1569202518dddf10ae547e799d75aef3bca3a2d9e8ee693"}, + {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:eb09b82377233b902d4c3fbeeb7ad731cdab579c6c6fda1f763cd779139e47c3"}, + {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:24658baf6224d8f280e827f0a50c46ad819ec8ba380a42448e24459daf809cf4"}, + {file = "cffi-1.17.0-cp39-cp39-win32.whl", hash = "sha256:0fdacad9e0d9fc23e519efd5ea24a70348305e8d7d85ecbb1a5fa66dc834e7fb"}, + {file = "cffi-1.17.0-cp39-cp39-win_amd64.whl", hash = "sha256:7cbc78dc018596315d4e7841c8c3a7ae31cc4d638c9b627f87d52e8abaaf2d29"}, + {file = "cffi-1.17.0.tar.gz", hash = "sha256:f3157624b7558b914cb039fd1af735e5e8049a87c817cc215109ad1c8779df76"}, ] [package.dependencies] @@ -1086,18 +1139,18 @@ tests = ["pytest", "pytest-cov", "pytest-xdist"] [[package]] name = "dask" -version = "2024.7.1" +version = "2024.8.1" description = "Parallel PyData with Task Scheduling" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" files = [ - {file = "dask-2024.7.1-py3-none-any.whl", hash = "sha256:dd046840050376c317de90629db5c6197adda820176cf3e2df10c3219d11951f"}, - {file = "dask-2024.7.1.tar.gz", hash = "sha256:dbaef2d50efee841a9d981a218cfeb50392fc9a95e0403b6d680450e4f50d531"}, + {file = "dask-2024.8.1-py3-none-any.whl", hash = "sha256:b8b58cba91dc9c057c8676dcc80b8bc321602b4dfd21529d33b03b55d428e2c3"}, + {file = "dask-2024.8.1.tar.gz", hash = "sha256:4254e43ac8c3affad2b22952f126b00a00f52c87caae91c068d8e395a4ad1a72"}, ] [package.dependencies] click = ">=8.1" -cloudpickle = ">=1.5.0" +cloudpickle = ">=3.0.0" dask-expr = {version = ">=1.1,<1.2", optional = true, markers = "extra == \"dataframe\""} fsspec = ">=2021.09.0" importlib-metadata = {version = ">=4.13.0", markers = "python_version < \"3.12\""} @@ -1113,22 +1166,22 @@ array = ["numpy (>=1.21)"] complete = ["dask[array,dataframe,diagnostics,distributed]", "lz4 (>=4.3.2)", "pyarrow (>=7.0)", "pyarrow-hotfix"] dataframe = ["dask-expr 
(>=1.1,<1.2)", "dask[array]", "pandas (>=2.0)"] diagnostics = ["bokeh (>=2.4.2)", "jinja2 (>=2.10.3)"] -distributed = ["distributed (==2024.7.1)"] +distributed = ["distributed (==2024.8.1)"] test = ["pandas[test]", "pre-commit", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-timeout", "pytest-xdist"] [[package]] name = "dask-expr" -version = "1.1.9" +version = "1.1.11" description = "High Level Expressions for Dask" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" files = [ - {file = "dask_expr-1.1.9-py3-none-any.whl", hash = "sha256:6a73d2ef8a4b697db476163ebfc661f7f4b6e12c4e483e8bf207e4d00e574f0c"}, - {file = "dask_expr-1.1.9.tar.gz", hash = "sha256:edb4f5a38b70a15beb4cfd449c71526c728ff989d91edec210836aa05556c766"}, + {file = "dask_expr-1.1.11-py3-none-any.whl", hash = "sha256:b9222b3d430152e3af4a1777f66bcee88651f510876cb57c720107d123d9ba63"}, + {file = "dask_expr-1.1.11.tar.gz", hash = "sha256:275689c269f9c30dbaf9d8d7e9d3b5ac5438ea6db73fdbf95b3f4bfb1381bc5a"}, ] [package.dependencies] -dask = "2024.7.1" +dask = "2024.8.1" pandas = ">=2" pyarrow = ">=7.0.0" @@ -1154,33 +1207,33 @@ pyarrow = ">=15.0.0,<16.0.0" [[package]] name = "debugpy" -version = "1.8.2" +version = "1.8.5" description = "An implementation of the Debug Adapter Protocol for Python" optional = false python-versions = ">=3.8" files = [ - {file = "debugpy-1.8.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7ee2e1afbf44b138c005e4380097d92532e1001580853a7cb40ed84e0ef1c3d2"}, - {file = "debugpy-1.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f8c3f7c53130a070f0fc845a0f2cee8ed88d220d6b04595897b66605df1edd6"}, - {file = "debugpy-1.8.2-cp310-cp310-win32.whl", hash = "sha256:f179af1e1bd4c88b0b9f0fa153569b24f6b6f3de33f94703336363ae62f4bf47"}, - {file = "debugpy-1.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:0600faef1d0b8d0e85c816b8bb0cb90ed94fc611f308d5fde28cb8b3d2ff0fe3"}, - {file = "debugpy-1.8.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:8a13417ccd5978a642e91fb79b871baded925d4fadd4dfafec1928196292aa0a"}, - {file = "debugpy-1.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acdf39855f65c48ac9667b2801234fc64d46778021efac2de7e50907ab90c634"}, - {file = "debugpy-1.8.2-cp311-cp311-win32.whl", hash = "sha256:2cbd4d9a2fc5e7f583ff9bf11f3b7d78dfda8401e8bb6856ad1ed190be4281ad"}, - {file = "debugpy-1.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:d3408fddd76414034c02880e891ea434e9a9cf3a69842098ef92f6e809d09afa"}, - {file = "debugpy-1.8.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:5d3ccd39e4021f2eb86b8d748a96c766058b39443c1f18b2dc52c10ac2757835"}, - {file = "debugpy-1.8.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:62658aefe289598680193ff655ff3940e2a601765259b123dc7f89c0239b8cd3"}, - {file = "debugpy-1.8.2-cp312-cp312-win32.whl", hash = "sha256:bd11fe35d6fd3431f1546d94121322c0ac572e1bfb1f6be0e9b8655fb4ea941e"}, - {file = "debugpy-1.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:15bc2f4b0f5e99bf86c162c91a74c0631dbd9cef3c6a1d1329c946586255e859"}, - {file = "debugpy-1.8.2-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:5a019d4574afedc6ead1daa22736c530712465c0c4cd44f820d803d937531b2d"}, - {file = "debugpy-1.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40f062d6877d2e45b112c0bbade9a17aac507445fd638922b1a5434df34aed02"}, - {file = "debugpy-1.8.2-cp38-cp38-win32.whl", hash = 
"sha256:c78ba1680f1015c0ca7115671fe347b28b446081dada3fedf54138f44e4ba031"}, - {file = "debugpy-1.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:cf327316ae0c0e7dd81eb92d24ba8b5e88bb4d1b585b5c0d32929274a66a5210"}, - {file = "debugpy-1.8.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:1523bc551e28e15147815d1397afc150ac99dbd3a8e64641d53425dba57b0ff9"}, - {file = "debugpy-1.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e24ccb0cd6f8bfaec68d577cb49e9c680621c336f347479b3fce060ba7c09ec1"}, - {file = "debugpy-1.8.2-cp39-cp39-win32.whl", hash = "sha256:7f8d57a98c5a486c5c7824bc0b9f2f11189d08d73635c326abef268f83950326"}, - {file = "debugpy-1.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:16c8dcab02617b75697a0a925a62943e26a0330da076e2a10437edd9f0bf3755"}, - {file = "debugpy-1.8.2-py2.py3-none-any.whl", hash = "sha256:16e16df3a98a35c63c3ab1e4d19be4cbc7fdda92d9ddc059294f18910928e0ca"}, - {file = "debugpy-1.8.2.zip", hash = "sha256:95378ed08ed2089221896b9b3a8d021e642c24edc8fef20e5d4342ca8be65c00"}, + {file = "debugpy-1.8.5-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:7e4d594367d6407a120b76bdaa03886e9eb652c05ba7f87e37418426ad2079f7"}, + {file = "debugpy-1.8.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4413b7a3ede757dc33a273a17d685ea2b0c09dbd312cc03f5534a0fd4d40750a"}, + {file = "debugpy-1.8.5-cp310-cp310-win32.whl", hash = "sha256:dd3811bd63632bb25eda6bd73bea8e0521794cda02be41fa3160eb26fc29e7ed"}, + {file = "debugpy-1.8.5-cp310-cp310-win_amd64.whl", hash = "sha256:b78c1250441ce893cb5035dd6f5fc12db968cc07f91cc06996b2087f7cefdd8e"}, + {file = "debugpy-1.8.5-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:606bccba19f7188b6ea9579c8a4f5a5364ecd0bf5a0659c8a5d0e10dcee3032a"}, + {file = "debugpy-1.8.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db9fb642938a7a609a6c865c32ecd0d795d56c1aaa7a7a5722d77855d5e77f2b"}, + {file = "debugpy-1.8.5-cp311-cp311-win32.whl", hash = "sha256:4fbb3b39ae1aa3e5ad578f37a48a7a303dad9a3d018d369bc9ec629c1cfa7408"}, + {file = "debugpy-1.8.5-cp311-cp311-win_amd64.whl", hash = "sha256:345d6a0206e81eb68b1493ce2fbffd57c3088e2ce4b46592077a943d2b968ca3"}, + {file = "debugpy-1.8.5-cp312-cp312-macosx_12_0_universal2.whl", hash = "sha256:5b5c770977c8ec6c40c60d6f58cacc7f7fe5a45960363d6974ddb9b62dbee156"}, + {file = "debugpy-1.8.5-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a65b00b7cdd2ee0c2cf4c7335fef31e15f1b7056c7fdbce9e90193e1a8c8cb"}, + {file = "debugpy-1.8.5-cp312-cp312-win32.whl", hash = "sha256:c9f7c15ea1da18d2fcc2709e9f3d6de98b69a5b0fff1807fb80bc55f906691f7"}, + {file = "debugpy-1.8.5-cp312-cp312-win_amd64.whl", hash = "sha256:28ced650c974aaf179231668a293ecd5c63c0a671ae6d56b8795ecc5d2f48d3c"}, + {file = "debugpy-1.8.5-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:3df6692351172a42af7558daa5019651f898fc67450bf091335aa8a18fbf6f3a"}, + {file = "debugpy-1.8.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1cd04a73eb2769eb0bfe43f5bfde1215c5923d6924b9b90f94d15f207a402226"}, + {file = "debugpy-1.8.5-cp38-cp38-win32.whl", hash = "sha256:8f913ee8e9fcf9d38a751f56e6de12a297ae7832749d35de26d960f14280750a"}, + {file = "debugpy-1.8.5-cp38-cp38-win_amd64.whl", hash = "sha256:a697beca97dad3780b89a7fb525d5e79f33821a8bc0c06faf1f1289e549743cf"}, + {file = "debugpy-1.8.5-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:0a1029a2869d01cb777216af8c53cda0476875ef02a2b6ff8b2f2c9a4b04176c"}, + {file = 
"debugpy-1.8.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84c276489e141ed0b93b0af648eef891546143d6a48f610945416453a8ad406"}, + {file = "debugpy-1.8.5-cp39-cp39-win32.whl", hash = "sha256:ad84b7cde7fd96cf6eea34ff6c4a1b7887e0fe2ea46e099e53234856f9d99a34"}, + {file = "debugpy-1.8.5-cp39-cp39-win_amd64.whl", hash = "sha256:7b0fe36ed9d26cb6836b0a51453653f8f2e347ba7348f2bbfe76bfeb670bfb1c"}, + {file = "debugpy-1.8.5-py2.py3-none-any.whl", hash = "sha256:55919dce65b471eff25901acf82d328bbd5b833526b6c1364bd5133754777a44"}, + {file = "debugpy-1.8.5.zip", hash = "sha256:b2112cfeb34b4507399d298fe7023a16656fc553ed5246536060ca7bd0e668d0"}, ] [[package]] @@ -1777,15 +1830,118 @@ files = [ {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] +[[package]] +name = "ijson" +version = "3.3.0" +description = "Iterative JSON parser with standard Python iterator interfaces" +optional = false +python-versions = "*" +files = [ + {file = "ijson-3.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7f7a5250599c366369fbf3bc4e176f5daa28eb6bc7d6130d02462ed335361675"}, + {file = "ijson-3.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f87a7e52f79059f9c58f6886c262061065eb6f7554a587be7ed3aa63e6b71b34"}, + {file = "ijson-3.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b73b493af9e947caed75d329676b1b801d673b17481962823a3e55fe529c8b8b"}, + {file = "ijson-3.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5576415f3d76290b160aa093ff968f8bf6de7d681e16e463a0134106b506f49"}, + {file = "ijson-3.3.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4e9ffe358d5fdd6b878a8a364e96e15ca7ca57b92a48f588378cef315a8b019e"}, + {file = "ijson-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8643c255a25824ddd0895c59f2319c019e13e949dc37162f876c41a283361527"}, + {file = "ijson-3.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:df3ab5e078cab19f7eaeef1d5f063103e1ebf8c26d059767b26a6a0ad8b250a3"}, + {file = "ijson-3.3.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3dc1fb02c6ed0bae1b4bf96971258bf88aea72051b6e4cebae97cff7090c0607"}, + {file = "ijson-3.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e9afd97339fc5a20f0542c971f90f3ca97e73d3050cdc488d540b63fae45329a"}, + {file = "ijson-3.3.0-cp310-cp310-win32.whl", hash = "sha256:844c0d1c04c40fd1b60f148dc829d3f69b2de789d0ba239c35136efe9a386529"}, + {file = "ijson-3.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:d654d045adafdcc6c100e8e911508a2eedbd2a1b5f93f930ba13ea67d7704ee9"}, + {file = "ijson-3.3.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:501dce8eaa537e728aa35810656aa00460a2547dcb60937c8139f36ec344d7fc"}, + {file = "ijson-3.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:658ba9cad0374d37b38c9893f4864f284cdcc7d32041f9808fba8c7bcaadf134"}, + {file = "ijson-3.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2636cb8c0f1023ef16173f4b9a233bcdb1df11c400c603d5f299fac143ca8d70"}, + {file = "ijson-3.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd174b90db68c3bcca273e9391934a25d76929d727dc75224bf244446b28b03b"}, + {file = "ijson-3.3.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:97a9aea46e2a8371c4cf5386d881de833ed782901ac9f67ebcb63bb3b7d115af"}, + {file = "ijson-3.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:c594c0abe69d9d6099f4ece17763d53072f65ba60b372d8ba6de8695ce6ee39e"}, + {file = "ijson-3.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8e0ff16c224d9bfe4e9e6bd0395826096cda4a3ef51e6c301e1b61007ee2bd24"}, + {file = "ijson-3.3.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:0015354011303175eae7e2ef5136414e91de2298e5a2e9580ed100b728c07e51"}, + {file = "ijson-3.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:034642558afa57351a0ffe6de89e63907c4cf6849070cc10a3b2542dccda1afe"}, + {file = "ijson-3.3.0-cp311-cp311-win32.whl", hash = "sha256:192e4b65495978b0bce0c78e859d14772e841724d3269fc1667dc6d2f53cc0ea"}, + {file = "ijson-3.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:72e3488453754bdb45c878e31ce557ea87e1eb0f8b4fc610373da35e8074ce42"}, + {file = "ijson-3.3.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:988e959f2f3d59ebd9c2962ae71b97c0df58323910d0b368cc190ad07429d1bb"}, + {file = "ijson-3.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b2f73f0d0fce5300f23a1383d19b44d103bb113b57a69c36fd95b7c03099b181"}, + {file = "ijson-3.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0ee57a28c6bf523d7cb0513096e4eb4dac16cd935695049de7608ec110c2b751"}, + {file = "ijson-3.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0155a8f079c688c2ccaea05de1ad69877995c547ba3d3612c1c336edc12a3a5"}, + {file = "ijson-3.3.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ab00721304af1ae1afa4313ecfa1bf16b07f55ef91e4a5b93aeaa3e2bd7917c"}, + {file = "ijson-3.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40ee3821ee90be0f0e95dcf9862d786a7439bd1113e370736bfdf197e9765bfb"}, + {file = "ijson-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:da3b6987a0bc3e6d0f721b42c7a0198ef897ae50579547b0345f7f02486898f5"}, + {file = "ijson-3.3.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:63afea5f2d50d931feb20dcc50954e23cef4127606cc0ecf7a27128ed9f9a9e6"}, + {file = "ijson-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b5c3e285e0735fd8c5a26d177eca8b52512cdd8687ca86ec77a0c66e9c510182"}, + {file = "ijson-3.3.0-cp312-cp312-win32.whl", hash = "sha256:907f3a8674e489abdcb0206723e5560a5cb1fa42470dcc637942d7b10f28b695"}, + {file = "ijson-3.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:8f890d04ad33262d0c77ead53c85f13abfb82f2c8f078dfbf24b78f59534dfdd"}, + {file = "ijson-3.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b9d85a02e77ee8ea6d9e3fd5d515bcc3d798d9c1ea54817e5feb97a9bc5d52fe"}, + {file = "ijson-3.3.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6576cdc36d5a09b0c1a3d81e13a45d41a6763188f9eaae2da2839e8a4240bce"}, + {file = "ijson-3.3.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5589225c2da4bb732c9c370c5961c39a6db72cf69fb2a28868a5413ed7f39e6"}, + {file = "ijson-3.3.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad04cf38164d983e85f9cba2804566c0160b47086dcca4cf059f7e26c5ace8ca"}, + {file = "ijson-3.3.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:a3b730ef664b2ef0e99dec01b6573b9b085c766400af363833e08ebc1e38eb2f"}, + {file = "ijson-3.3.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:4690e3af7b134298055993fcbea161598d23b6d3ede11b12dca6815d82d101d5"}, + {file = "ijson-3.3.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:aaa6bfc2180c31a45fac35d40e3312a3d09954638ce0b2e9424a88e24d262a13"}, + {file = "ijson-3.3.0-cp36-cp36m-win32.whl", hash = 
"sha256:44367090a5a876809eb24943f31e470ba372aaa0d7396b92b953dda953a95d14"}, + {file = "ijson-3.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:7e2b3e9ca957153557d06c50a26abaf0d0d6c0ddf462271854c968277a6b5372"}, + {file = "ijson-3.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:47c144117e5c0e2babb559bc8f3f76153863b8dd90b2d550c51dab5f4b84a87f"}, + {file = "ijson-3.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29ce02af5fbf9ba6abb70765e66930aedf73311c7d840478f1ccecac53fefbf3"}, + {file = "ijson-3.3.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4ac6c3eeed25e3e2cb9b379b48196413e40ac4e2239d910bb33e4e7f6c137745"}, + {file = "ijson-3.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d92e339c69b585e7b1d857308ad3ca1636b899e4557897ccd91bb9e4a56c965b"}, + {file = "ijson-3.3.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:8c85447569041939111b8c7dbf6f8fa7a0eb5b2c4aebb3c3bec0fb50d7025121"}, + {file = "ijson-3.3.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:542c1e8fddf082159a5d759ee1412c73e944a9a2412077ed00b303ff796907dc"}, + {file = "ijson-3.3.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:30cfea40936afb33b57d24ceaf60d0a2e3d5c1f2335ba2623f21d560737cc730"}, + {file = "ijson-3.3.0-cp37-cp37m-win32.whl", hash = "sha256:6b661a959226ad0d255e49b77dba1d13782f028589a42dc3172398dd3814c797"}, + {file = "ijson-3.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:0b003501ee0301dbf07d1597482009295e16d647bb177ce52076c2d5e64113e0"}, + {file = "ijson-3.3.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3e8d8de44effe2dbd0d8f3eb9840344b2d5b4cc284a14eb8678aec31d1b6bea8"}, + {file = "ijson-3.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9cd5c03c63ae06d4f876b9844c5898d0044c7940ff7460db9f4cd984ac7862b5"}, + {file = "ijson-3.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04366e7e4a4078d410845e58a2987fd9c45e63df70773d7b6e87ceef771b51ee"}, + {file = "ijson-3.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de7c1ddb80fa7a3ab045266dca169004b93f284756ad198306533b792774f10a"}, + {file = "ijson-3.3.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8851584fb931cffc0caa395f6980525fd5116eab8f73ece9d95e6f9c2c326c4c"}, + {file = "ijson-3.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdcfc88347fd981e53c33d832ce4d3e981a0d696b712fbcb45dcc1a43fe65c65"}, + {file = "ijson-3.3.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3917b2b3d0dbbe3296505da52b3cb0befbaf76119b2edaff30bd448af20b5400"}, + {file = "ijson-3.3.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:e10c14535abc7ddf3fd024aa36563cd8ab5d2bb6234a5d22c77c30e30fa4fb2b"}, + {file = "ijson-3.3.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:3aba5c4f97f4e2ce854b5591a8b0711ca3b0c64d1b253b04ea7b004b0a197ef6"}, + {file = "ijson-3.3.0-cp38-cp38-win32.whl", hash = "sha256:b325f42e26659df1a0de66fdb5cde8dd48613da9c99c07d04e9fb9e254b7ee1c"}, + {file = "ijson-3.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:ff835906f84451e143f31c4ce8ad73d83ef4476b944c2a2da91aec8b649570e1"}, + {file = "ijson-3.3.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3c556f5553368dff690c11d0a1fb435d4ff1f84382d904ccc2dc53beb27ba62e"}, + {file = "ijson-3.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e4396b55a364a03ff7e71a34828c3ed0c506814dd1f50e16ebed3fc447d5188e"}, + {file = "ijson-3.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:e6850ae33529d1e43791b30575070670070d5fe007c37f5d06aebc1dd152ab3f"}, + {file = "ijson-3.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36aa56d68ea8def26778eb21576ae13f27b4a47263a7a2581ab2ef58b8de4451"}, + {file = "ijson-3.3.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7ec759c4a0fc820ad5dc6a58e9c391e7b16edcb618056baedbedbb9ea3b1524"}, + {file = "ijson-3.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b51bab2c4e545dde93cb6d6bb34bf63300b7cd06716f195dd92d9255df728331"}, + {file = "ijson-3.3.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:92355f95a0e4da96d4c404aa3cff2ff033f9180a9515f813255e1526551298c1"}, + {file = "ijson-3.3.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:8795e88adff5aa3c248c1edce932db003d37a623b5787669ccf205c422b91e4a"}, + {file = "ijson-3.3.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:8f83f553f4cde6d3d4eaf58ec11c939c94a0ec545c5b287461cafb184f4b3a14"}, + {file = "ijson-3.3.0-cp39-cp39-win32.whl", hash = "sha256:ead50635fb56577c07eff3e557dac39533e0fe603000684eea2af3ed1ad8f941"}, + {file = "ijson-3.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:c8a9befb0c0369f0cf5c1b94178d0d78f66d9cebb9265b36be6e4f66236076b8"}, + {file = "ijson-3.3.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2af323a8aec8a50fa9effa6d640691a30a9f8c4925bd5364a1ca97f1ac6b9b5c"}, + {file = "ijson-3.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f64f01795119880023ba3ce43072283a393f0b90f52b66cc0ea1a89aa64a9ccb"}, + {file = "ijson-3.3.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a716e05547a39b788deaf22725490855337fc36613288aa8ae1601dc8c525553"}, + {file = "ijson-3.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:473f5d921fadc135d1ad698e2697025045cd8ed7e5e842258295012d8a3bc702"}, + {file = "ijson-3.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd26b396bc3a1e85f4acebeadbf627fa6117b97f4c10b177d5779577c6607744"}, + {file = "ijson-3.3.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:25fd49031cdf5fd5f1fd21cb45259a64dad30b67e64f745cc8926af1c8c243d3"}, + {file = "ijson-3.3.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b72178b1e565d06ab19319965022b36ef41bcea7ea153b32ec31194bec032a2"}, + {file = "ijson-3.3.0-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d0b6b637d05dbdb29d0bfac2ed8425bb369e7af5271b0cc7cf8b801cb7360c2"}, + {file = "ijson-3.3.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5378d0baa59ae422905c5f182ea0fd74fe7e52a23e3821067a7d58c8306b2191"}, + {file = "ijson-3.3.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:99f5c8ab048ee4233cc4f2b461b205cbe01194f6201018174ac269bf09995749"}, + {file = "ijson-3.3.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:45ff05de889f3dc3d37a59d02096948ce470699f2368b32113954818b21aa74a"}, + {file = "ijson-3.3.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1efb521090dd6cefa7aafd120581947b29af1713c902ff54336b7c7130f04c47"}, + {file = "ijson-3.3.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87c727691858fd3a1c085d9980d12395517fcbbf02c69fbb22dede8ee03422da"}, + {file = "ijson-3.3.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0420c24e50389bc251b43c8ed379ab3e3ba065ac8262d98beb6735ab14844460"}, + {file = 
"ijson-3.3.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:8fdf3721a2aa7d96577970f5604bd81f426969c1822d467f07b3d844fa2fecc7"}, + {file = "ijson-3.3.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:891f95c036df1bc95309951940f8eea8537f102fa65715cdc5aae20b8523813b"}, + {file = "ijson-3.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed1336a2a6e5c427f419da0154e775834abcbc8ddd703004108121c6dd9eba9d"}, + {file = "ijson-3.3.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0c819f83e4f7b7f7463b2dc10d626a8be0c85fbc7b3db0edc098c2b16ac968e"}, + {file = "ijson-3.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33afc25057377a6a43c892de34d229a86f89ea6c4ca3dd3db0dcd17becae0dbb"}, + {file = "ijson-3.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7914d0cf083471856e9bc2001102a20f08e82311dfc8cf1a91aa422f9414a0d6"}, + {file = "ijson-3.3.0.tar.gz", hash = "sha256:7f172e6ba1bee0d4c8f8ebd639577bfe429dee0f3f96775a067b8bae4492d8a0"}, +] + [[package]] name = "importlib-metadata" -version = "8.2.0" +version = "8.4.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-8.2.0-py3-none-any.whl", hash = "sha256:11901fa0c2f97919b288679932bb64febaeacf289d18ac84dd68cb2e74213369"}, - {file = "importlib_metadata-8.2.0.tar.gz", hash = "sha256:72e8d4399996132204f9a16dcc751af254a48f8d1b20b9ff0f98d4a8f901e73d"}, + {file = "importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1"}, + {file = "importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5"}, ] [package.dependencies] @@ -1963,6 +2119,76 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jiter" +version = "0.5.0" +description = "Fast iterable JSON parser." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "jiter-0.5.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b599f4e89b3def9a94091e6ee52e1d7ad7bc33e238ebb9c4c63f211d74822c3f"}, + {file = "jiter-0.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a063f71c4b06225543dddadbe09d203dc0c95ba352d8b85f1221173480a71d5"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:acc0d5b8b3dd12e91dd184b87273f864b363dfabc90ef29a1092d269f18c7e28"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c22541f0b672f4d741382a97c65609332a783501551445ab2df137ada01e019e"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:63314832e302cc10d8dfbda0333a384bf4bcfce80d65fe99b0f3c0da8945a91a"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a25fbd8a5a58061e433d6fae6d5298777c0814a8bcefa1e5ecfff20c594bd749"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:503b2c27d87dfff5ab717a8200fbbcf4714516c9d85558048b1fc14d2de7d8dc"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6d1f3d27cce923713933a844872d213d244e09b53ec99b7a7fdf73d543529d6d"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c95980207b3998f2c3b3098f357994d3fd7661121f30669ca7cb945f09510a87"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:afa66939d834b0ce063f57d9895e8036ffc41c4bd90e4a99631e5f261d9b518e"}, + {file = "jiter-0.5.0-cp310-none-win32.whl", hash = "sha256:f16ca8f10e62f25fd81d5310e852df6649af17824146ca74647a018424ddeccf"}, + {file = "jiter-0.5.0-cp310-none-win_amd64.whl", hash = "sha256:b2950e4798e82dd9176935ef6a55cf6a448b5c71515a556da3f6b811a7844f1e"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d4c8e1ed0ef31ad29cae5ea16b9e41529eb50a7fba70600008e9f8de6376d553"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c6f16e21276074a12d8421692515b3fd6d2ea9c94fd0734c39a12960a20e85f3"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5280e68e7740c8c128d3ae5ab63335ce6d1fb6603d3b809637b11713487af9e6"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:583c57fc30cc1fec360e66323aadd7fc3edeec01289bfafc35d3b9dcb29495e4"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26351cc14507bdf466b5f99aba3df3143a59da75799bf64a53a3ad3155ecded9"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4829df14d656b3fb87e50ae8b48253a8851c707da9f30d45aacab2aa2ba2d614"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a42a4bdcf7307b86cb863b2fb9bb55029b422d8f86276a50487982d99eed7c6e"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04d461ad0aebf696f8da13c99bc1b3e06f66ecf6cfd56254cc402f6385231c06"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e6375923c5f19888c9226582a124b77b622f8fd0018b843c45eeb19d9701c403"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2cec323a853c24fd0472517113768c92ae0be8f8c384ef4441d3632da8baa646"}, + {file = "jiter-0.5.0-cp311-none-win32.whl", hash = 
"sha256:aa1db0967130b5cab63dfe4d6ff547c88b2a394c3410db64744d491df7f069bb"}, + {file = "jiter-0.5.0-cp311-none-win_amd64.whl", hash = "sha256:aa9d2b85b2ed7dc7697597dcfaac66e63c1b3028652f751c81c65a9f220899ae"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9f664e7351604f91dcdd557603c57fc0d551bc65cc0a732fdacbf73ad335049a"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:044f2f1148b5248ad2c8c3afb43430dccf676c5a5834d2f5089a4e6c5bbd64df"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:702e3520384c88b6e270c55c772d4bd6d7b150608dcc94dea87ceba1b6391248"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:528d742dcde73fad9d63e8242c036ab4a84389a56e04efd854062b660f559544"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8cf80e5fe6ab582c82f0c3331df27a7e1565e2dcf06265afd5173d809cdbf9ba"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:44dfc9ddfb9b51a5626568ef4e55ada462b7328996294fe4d36de02fce42721f"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c451f7922992751a936b96c5f5b9bb9312243d9b754c34b33d0cb72c84669f4e"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:308fce789a2f093dca1ff91ac391f11a9f99c35369117ad5a5c6c4903e1b3e3a"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7f5ad4a7c6b0d90776fdefa294f662e8a86871e601309643de30bf94bb93a64e"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ea189db75f8eca08807d02ae27929e890c7d47599ce3d0a6a5d41f2419ecf338"}, + {file = "jiter-0.5.0-cp312-none-win32.whl", hash = "sha256:e3bbe3910c724b877846186c25fe3c802e105a2c1fc2b57d6688b9f8772026e4"}, + {file = "jiter-0.5.0-cp312-none-win_amd64.whl", hash = "sha256:a586832f70c3f1481732919215f36d41c59ca080fa27a65cf23d9490e75b2ef5"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f04bc2fc50dc77be9d10f73fcc4e39346402ffe21726ff41028f36e179b587e6"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6f433a4169ad22fcb550b11179bb2b4fd405de9b982601914ef448390b2954f3"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad4a6398c85d3a20067e6c69890ca01f68659da94d74c800298581724e426c7e"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6baa88334e7af3f4d7a5c66c3a63808e5efbc3698a1c57626541ddd22f8e4fbf"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ece0a115c05efca597c6d938f88c9357c843f8c245dbbb53361a1c01afd7148"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:335942557162ad372cc367ffaf93217117401bf930483b4b3ebdb1223dbddfa7"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649b0ee97a6e6da174bffcb3c8c051a5935d7d4f2f52ea1583b5b3e7822fbf14"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f4be354c5de82157886ca7f5925dbda369b77344b4b4adf2723079715f823989"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5206144578831a6de278a38896864ded4ed96af66e1e63ec5dd7f4a1fce38a3a"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:8120c60f8121ac3d6f072b97ef0e71770cc72b3c23084c72c4189428b1b1d3b6"}, + {file = "jiter-0.5.0-cp38-none-win32.whl", hash = "sha256:6f1223f88b6d76b519cb033a4d3687ca157c272ec5d6015c322fc5b3074d8a5e"}, + {file = "jiter-0.5.0-cp38-none-win_amd64.whl", hash = "sha256:c59614b225d9f434ea8fc0d0bec51ef5fa8c83679afedc0433905994fb36d631"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0af3838cfb7e6afee3f00dc66fa24695199e20ba87df26e942820345b0afc566"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:550b11d669600dbc342364fd4adbe987f14d0bbedaf06feb1b983383dcc4b961"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:489875bf1a0ffb3cb38a727b01e6673f0f2e395b2aad3c9387f94187cb214bbf"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b250ca2594f5599ca82ba7e68785a669b352156260c5362ea1b4e04a0f3e2389"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ea18e01f785c6667ca15407cd6dabbe029d77474d53595a189bdc813347218e"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:462a52be85b53cd9bffd94e2d788a09984274fe6cebb893d6287e1c296d50653"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92cc68b48d50fa472c79c93965e19bd48f40f207cb557a8346daa020d6ba973b"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1c834133e59a8521bc87ebcad773608c6fa6ab5c7a022df24a45030826cf10bc"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab3a71ff31cf2d45cb216dc37af522d335211f3a972d2fe14ea99073de6cb104"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cccd3af9c48ac500c95e1bcbc498020c87e1781ff0345dd371462d67b76643eb"}, + {file = "jiter-0.5.0-cp39-none-win32.whl", hash = "sha256:368084d8d5c4fc40ff7c3cc513c4f73e02c85f6009217922d0823a48ee7adf61"}, + {file = "jiter-0.5.0-cp39-none-win_amd64.whl", hash = "sha256:ce03f7b4129eb72f1687fa11300fbf677b02990618428934662406d2a76742a1"}, + {file = "jiter-0.5.0.tar.gz", hash = "sha256:1d916ba875bcab5c5f7d927df998c4cb694d27dceddf3392e58beaf10563368a"}, +] + [[package]] name = "joblib" version = "1.4.2" @@ -2618,13 +2844,13 @@ files = [ [[package]] name = "marshmallow" -version = "3.21.3" +version = "3.22.0" description = "A lightweight library for converting complex datatypes to and from native Python datatypes." 
optional = false
python-versions = ">=3.8"
files = [
- {file = "marshmallow-3.21.3-py3-none-any.whl", hash = "sha256:86ce7fb914aa865001a4b2092c4c2872d13bc347f3d42673272cabfdbad386f1"},
- {file = "marshmallow-3.21.3.tar.gz", hash = "sha256:4f57c5e050a54d66361e826f94fba213eb10b67b2fdb02c3e0343ce207ba1662"},
+ {file = "marshmallow-3.22.0-py3-none-any.whl", hash = "sha256:71a2dce49ef901c3f97ed296ae5051135fd3febd2bf43afe0ae9a82143a494d9"},
+ {file = "marshmallow-3.22.0.tar.gz", hash = "sha256:4972f529104a220bb8637d595aa4c9762afbe7f7a77d82dc58c1615d70c5823e"},
]
@@ -2632,45 +2858,56 @@ packaging = ">=17.0"
[package.extras]
dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"]
-docs = ["alabaster (==0.7.16)", "autodocsumm (==0.2.12)", "sphinx (==7.3.7)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"]
+docs = ["alabaster (==1.0.0)", "autodocsumm (==0.2.13)", "sphinx (==8.0.2)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"]
tests = ["pytest", "pytz", "simplejson"]
[[package]]
name = "matplotlib"
-version = "3.9.0"
+version = "3.9.2"
description = "Python plotting package"
optional = false
python-versions = ">=3.9"
files = [
- {file = "matplotlib-3.9.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2bcee1dffaf60fe7656183ac2190bd630842ff87b3153afb3e384d966b57fe56"},
- {file = "matplotlib-3.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3f988bafb0fa39d1074ddd5bacd958c853e11def40800c5824556eb630f94d3b"},
- {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe428e191ea016bb278758c8ee82a8129c51d81d8c4bc0846c09e7e8e9057241"},
- {file = "matplotlib-3.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf3978060a106fab40c328778b148f590e27f6fa3cd15a19d6892575bce387d"},
- {file = "matplotlib-3.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e7f03e5cbbfacdd48c8ea394d365d91ee8f3cae7e6ec611409927b5ed997ee4"},
- {file = "matplotlib-3.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:13beb4840317d45ffd4183a778685e215939be7b08616f431c7795276e067463"},
- {file = "matplotlib-3.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38"},
- {file = "matplotlib-3.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9a2fa6d899e17ddca6d6526cf6e7ba677738bf2a6a9590d702c277204a7c6152"},
- {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:550cdda3adbd596078cca7d13ed50b77879104e2e46392dcd7c75259d8f00e85"},
- {file = "matplotlib-3.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76cce0f31b351e3551d1f3779420cf8f6ec0d4a8cf9c0237a3b549fd28eb4abb"},
- {file = "matplotlib-3.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c53aeb514ccbbcbab55a27f912d79ea30ab21ee0531ee2c09f13800efb272674"},
- {file = "matplotlib-3.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:a5be985db2596d761cdf0c2eaf52396f26e6a64ab46bd8cd810c48972349d1be"},
- {file = "matplotlib-3.9.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c79f3a585f1368da6049318bdf1f85568d8d04b2e89fc24b7e02cc9b62017382"},
- {file = "matplotlib-3.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bdd1ecbe268eb3e7653e04f451635f0fb0f77f07fd070242b44c076c9106da84"},
- {file = "matplotlib-3.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d38e85a1a6d732f645f1403ce5e6727fd9418cd4574521d5803d3d94911038e5"},
- {file =
"matplotlib-3.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db"}, - {file = "matplotlib-3.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8146ce83cbc5dc71c223a74a1996d446cd35cfb6a04b683e1446b7e6c73603b7"}, - {file = "matplotlib-3.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:d91a4ffc587bacf5c4ce4ecfe4bcd23a4b675e76315f2866e588686cc97fccdf"}, - {file = "matplotlib-3.9.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:616fabf4981a3b3c5a15cd95eba359c8489c4e20e03717aea42866d8d0465956"}, - {file = "matplotlib-3.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd53c79fd02f1c1808d2cfc87dd3cf4dbc63c5244a58ee7944497107469c8d8a"}, - {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321"}, - {file = "matplotlib-3.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81c40af649d19c85f8073e25e5806926986806fa6d54be506fbf02aef47d5a89"}, - {file = "matplotlib-3.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52146fc3bd7813cc784562cb93a15788be0b2875c4655e2cc6ea646bfa30344b"}, - {file = "matplotlib-3.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:0fc51eaa5262553868461c083d9adadb11a6017315f3a757fc45ec6ec5f02888"}, - {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bd4f2831168afac55b881db82a7730992aa41c4f007f1913465fb182d6fb20c0"}, - {file = "matplotlib-3.9.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:290d304e59be2b33ef5c2d768d0237f5bd132986bdcc66f80bc9bcc300066a03"}, - {file = "matplotlib-3.9.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ff2e239c26be4f24bfa45860c20ffccd118d270c5b5d081fa4ea409b5469fcd"}, - {file = "matplotlib-3.9.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:af4001b7cae70f7eaacfb063db605280058246de590fa7874f00f62259f2df7e"}, - {file = "matplotlib-3.9.0.tar.gz", hash = "sha256:e6d29ea6c19e34b30fb7d88b7081f869a03014f66fe06d62cc77d5a6ea88ed7a"}, + {file = "matplotlib-3.9.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:9d78bbc0cbc891ad55b4f39a48c22182e9bdaea7fc0e5dbd364f49f729ca1bbb"}, + {file = "matplotlib-3.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c375cc72229614632c87355366bdf2570c2dac01ac66b8ad048d2dabadf2d0d4"}, + {file = "matplotlib-3.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d94ff717eb2bd0b58fe66380bd8b14ac35f48a98e7c6765117fe67fb7684e64"}, + {file = "matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab68d50c06938ef28681073327795c5db99bb4666214d2d5f880ed11aeaded66"}, + {file = "matplotlib-3.9.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:65aacf95b62272d568044531e41de26285d54aec8cb859031f511f84bd8b495a"}, + {file = "matplotlib-3.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:3fd595f34aa8a55b7fc8bf9ebea8aa665a84c82d275190a61118d33fbc82ccae"}, + {file = "matplotlib-3.9.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d8dd059447824eec055e829258ab092b56bb0579fc3164fa09c64f3acd478772"}, + {file = "matplotlib-3.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c797dac8bb9c7a3fd3382b16fe8f215b4cf0f22adccea36f1545a6d7be310b41"}, + {file = "matplotlib-3.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d719465db13267bcef19ea8954a971db03b9f48b4647e3860e4bc8e6ed86610f"}, + {file = 
"matplotlib-3.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8912ef7c2362f7193b5819d17dae8629b34a95c58603d781329712ada83f9447"}, + {file = "matplotlib-3.9.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:7741f26a58a240f43bee74965c4882b6c93df3e7eb3de160126d8c8f53a6ae6e"}, + {file = "matplotlib-3.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:ae82a14dab96fbfad7965403c643cafe6515e386de723e498cf3eeb1e0b70cc7"}, + {file = "matplotlib-3.9.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:ac43031375a65c3196bee99f6001e7fa5bdfb00ddf43379d3c0609bdca042df9"}, + {file = "matplotlib-3.9.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:be0fc24a5e4531ae4d8e858a1a548c1fe33b176bb13eff7f9d0d38ce5112a27d"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf81de2926c2db243c9b2cbc3917619a0fc85796c6ba4e58f541df814bbf83c7"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c"}, + {file = "matplotlib-3.9.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:306c8dfc73239f0e72ac50e5a9cf19cc4e8e331dd0c54f5e69ca8758550f1e1e"}, + {file = "matplotlib-3.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:5413401594cfaff0052f9d8b1aafc6d305b4bd7c4331dccd18f561ff7e1d3bd3"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:18128cc08f0d3cfff10b76baa2f296fc28c4607368a8402de61bb3f2eb33c7d9"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4876d7d40219e8ae8bb70f9263bcbe5714415acfdf781086601211335e24f8aa"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d9f07a80deab4bb0b82858a9e9ad53d1382fd122be8cde11080f4e7dfedb38b"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413"}, + {file = "matplotlib-3.9.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:909645cce2dc28b735674ce0931a4ac94e12f5b13f6bb0b5a5e65e7cea2c192b"}, + {file = "matplotlib-3.9.2-cp313-cp313-win_amd64.whl", hash = "sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:37e51dd1c2db16ede9cfd7b5cabdfc818b2c6397c83f8b10e0e797501c963a03"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b82c5045cebcecd8496a4d694d43f9cc84aeeb49fe2133e036b207abe73f4d30"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f053c40f94bc51bc03832a41b4f153d83f2062d88c72b5e79997072594e97e51"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbe196377a8248972f5cede786d4c5508ed5f5ca4a1e09b44bda889958b33f8c"}, + {file = "matplotlib-3.9.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5816b1e1fe8c192cbc013f8f3e3368ac56fbecf02fb41b8f8559303f24c5015e"}, + {file = "matplotlib-3.9.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:cef2a73d06601437be399908cf13aee74e86932a5ccc6ccdf173408ebc5f6bb2"}, + {file = "matplotlib-3.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e0830e188029c14e891fadd99702fd90d317df294c3298aad682739c5533721a"}, + {file = "matplotlib-3.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:03ba9c1299c920964e8d3857ba27173b4dbb51ca4bab47ffc2c2ba0eb5e2cbc5"}, + {file = "matplotlib-3.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1cd93b91ab47a3616b4d3c42b52f8363b88ca021e340804c6ab2536344fad9ca"}, + {file = "matplotlib-3.9.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:6d1ce5ed2aefcdce11904fc5bbea7d9c21fff3d5f543841edf3dea84451a09ea"}, + {file = "matplotlib-3.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:b2696efdc08648536efd4e1601b5fd491fd47f4db97a5fbfd175549a7365c1b2"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d52a3b618cb1cbb769ce2ee1dcdb333c3ab6e823944e9a2d36e37253815f9556"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:039082812cacd6c6bec8e17a9c1e6baca230d4116d522e81e1f63a74d01d2e21"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6758baae2ed64f2331d4fd19be38b7b4eae3ecec210049a26b6a4f3ae1c85dcc"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:050598c2b29e0b9832cde72bcf97627bf00262adbc4a54e2b856426bb2ef0697"}, + {file = "matplotlib-3.9.2.tar.gz", hash = "sha256:96ab43906269ca64a6366934106fa01534454a69e471b7bf3d79083981aaab92"}, ] [package.dependencies] @@ -3141,23 +3378,24 @@ files = [ [[package]] name = "openai" -version = "1.39.0" +version = "1.42.0" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.39.0-py3-none-any.whl", hash = "sha256:a712553a131c59a249c474d0bb6a0414f41df36dc186d3a018fa7e600e57fb7f"}, - {file = "openai-1.39.0.tar.gz", hash = "sha256:0cea446082f50985f26809d704a97749cb366a1ba230ef432c684a9745b3f2d9"}, + {file = "openai-1.42.0-py3-none-any.whl", hash = "sha256:dc91e0307033a4f94931e5d03cc3b29b9717014ad5e73f9f2051b6cb5eda4d80"}, + {file = "openai-1.42.0.tar.gz", hash = "sha256:c9d31853b4e0bc2dc8bd08003b462a006035655a701471695d0bfdc08529cde3"}, ] [package.dependencies] anyio = ">=3.5.0,<5" distro = ">=1.7.0,<2" httpx = ">=0.23.0,<1" +jiter = ">=0.4.0,<1" pydantic = ">=1.9.0,<3" sniffio = "*" tqdm = ">4" -typing-extensions = ">=4.7,<5" +typing-extensions = ">=4.11,<5" [package.extras] datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] @@ -3956,13 +4194,13 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pyright" -version = "1.1.374" +version = "1.1.377" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.374-py3-none-any.whl", hash = "sha256:55752bcf7a3646d293cd76710a983b71e16f6128aab2d42468e6eb7e46c0a70d"}, - {file = "pyright-1.1.374.tar.gz", hash = "sha256:d01b2daf864ba5e0362e56b844984865970d7204158e61eb685e2dab7804cb82"}, + {file = "pyright-1.1.377-py3-none-any.whl", hash = "sha256:af0dd2b6b636c383a6569a083f8c5a8748ae4dcde5df7914b3f3f267e14dd162"}, + {file = "pyright-1.1.377.tar.gz", hash = "sha256:aabc30fedce0ded34baa0c49b24f10e68f4bfc8f68ae7f3d175c4b0f256b4fcf"}, ] [package.dependencies] @@ -4116,180 +4354,182 @@ files = [ [[package]] name = "pyyaml" -version = "6.0.1" +version = "6.0.2" description = "YAML parser and emitter for Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, - {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, - {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, - {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, - {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, - {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, - {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, - {file = 
"PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, - {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, - {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, - {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, - {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, - {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = 
"sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, - {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, - {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = 
"PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] [[package]] name = "pyzmq" -version = "26.1.0" +version = "26.1.1" description = "Python bindings for 0MQ" optional = false python-versions = ">=3.7" files = [ - {file = "pyzmq-26.1.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:263cf1e36862310bf5becfbc488e18d5d698941858860c5a8c079d1511b3b18e"}, - {file = "pyzmq-26.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d5c8b17f6e8f29138678834cf8518049e740385eb2dbf736e8f07fc6587ec682"}, - {file = "pyzmq-26.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:75a95c2358fcfdef3374cb8baf57f1064d73246d55e41683aaffb6cfe6862917"}, - {file = "pyzmq-26.1.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f99de52b8fbdb2a8f5301ae5fc0f9e6b3ba30d1d5fc0421956967edcc6914242"}, - {file = "pyzmq-26.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bcbfbab4e1895d58ab7da1b5ce9a327764f0366911ba5b95406c9104bceacb0"}, - {file = "pyzmq-26.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:77ce6a332c7e362cb59b63f5edf730e83590d0ab4e59c2aa5bd79419a42e3449"}, - {file = "pyzmq-26.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ba0a31d00e8616149a5ab440d058ec2da621e05d744914774c4dde6837e1f545"}, - {file = "pyzmq-26.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8b88641384e84a258b740801cd4dbc45c75f148ee674bec3149999adda4a8598"}, - {file = "pyzmq-26.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2fa76ebcebe555cce90f16246edc3ad83ab65bb7b3d4ce408cf6bc67740c4f88"}, - {file = "pyzmq-26.1.0-cp310-cp310-win32.whl", hash = "sha256:fbf558551cf415586e91160d69ca6416f3fce0b86175b64e4293644a7416b81b"}, - {file = "pyzmq-26.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:a7b8aab50e5a288c9724d260feae25eda69582be84e97c012c80e1a5e7e03fb2"}, - {file = "pyzmq-26.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:08f74904cb066e1178c1ec706dfdb5c6c680cd7a8ed9efebeac923d84c1f13b1"}, - {file = "pyzmq-26.1.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:46d6800b45015f96b9d92ece229d92f2aef137d82906577d55fadeb9cf5fcb71"}, - {file = "pyzmq-26.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5bc2431167adc50ba42ea3e5e5f5cd70d93e18ab7b2f95e724dd8e1bd2c38120"}, - {file = 
"pyzmq-26.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3bb34bebaa1b78e562931a1687ff663d298013f78f972a534f36c523311a84d"}, - {file = "pyzmq-26.1.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd3f6329340cef1c7ba9611bd038f2d523cea79f09f9c8f6b0553caba59ec562"}, - {file = "pyzmq-26.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:471880c4c14e5a056a96cd224f5e71211997d40b4bf5e9fdded55dafab1f98f2"}, - {file = "pyzmq-26.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ce6f2b66799971cbae5d6547acefa7231458289e0ad481d0be0740535da38d8b"}, - {file = "pyzmq-26.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0a1f6ea5b1d6cdbb8cfa0536f0d470f12b4b41ad83625012e575f0e3ecfe97f0"}, - {file = "pyzmq-26.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b45e6445ac95ecb7d728604bae6538f40ccf4449b132b5428c09918523abc96d"}, - {file = "pyzmq-26.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:94c4262626424683feea0f3c34951d39d49d354722db2745c42aa6bb50ecd93b"}, - {file = "pyzmq-26.1.0-cp311-cp311-win32.whl", hash = "sha256:a0f0ab9df66eb34d58205913f4540e2ad17a175b05d81b0b7197bc57d000e829"}, - {file = "pyzmq-26.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:8efb782f5a6c450589dbab4cb0f66f3a9026286333fe8f3a084399149af52f29"}, - {file = "pyzmq-26.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:f133d05aaf623519f45e16ab77526e1e70d4e1308e084c2fb4cedb1a0c764bbb"}, - {file = "pyzmq-26.1.0-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:3d3146b1c3dcc8a1539e7cc094700b2be1e605a76f7c8f0979b6d3bde5ad4072"}, - {file = "pyzmq-26.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d9270fbf038bf34ffca4855bcda6e082e2c7f906b9eb8d9a8ce82691166060f7"}, - {file = "pyzmq-26.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:995301f6740a421afc863a713fe62c0aaf564708d4aa057dfdf0f0f56525294b"}, - {file = "pyzmq-26.1.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7eca8b89e56fb8c6c26dd3e09bd41b24789022acf1cf13358e96f1cafd8cae3"}, - {file = "pyzmq-26.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d4feb2e83dfe9ace6374a847e98ee9d1246ebadcc0cb765482e272c34e5820"}, - {file = "pyzmq-26.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:d4fafc2eb5d83f4647331267808c7e0c5722c25a729a614dc2b90479cafa78bd"}, - {file = "pyzmq-26.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:58c33dc0e185dd97a9ac0288b3188d1be12b756eda67490e6ed6a75cf9491d79"}, - {file = "pyzmq-26.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:68a0a1d83d33d8367ddddb3e6bb4afbb0f92bd1dac2c72cd5e5ddc86bdafd3eb"}, - {file = "pyzmq-26.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ae7c57e22ad881af78075e0cea10a4c778e67234adc65c404391b417a4dda83"}, - {file = "pyzmq-26.1.0-cp312-cp312-win32.whl", hash = "sha256:347e84fc88cc4cb646597f6d3a7ea0998f887ee8dc31c08587e9c3fd7b5ccef3"}, - {file = "pyzmq-26.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:9f136a6e964830230912f75b5a116a21fe8e34128dcfd82285aa0ef07cb2c7bd"}, - {file = "pyzmq-26.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:a4b7a989c8f5a72ab1b2bbfa58105578753ae77b71ba33e7383a31ff75a504c4"}, - {file = "pyzmq-26.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d416f2088ac8f12daacffbc2e8918ef4d6be8568e9d7155c83b7cebed49d2322"}, - {file = "pyzmq-26.1.0-cp313-cp313-macosx_10_15_universal2.whl", hash = "sha256:ecb6c88d7946166d783a635efc89f9a1ff11c33d680a20df9657b6902a1d133b"}, - 
{file = "pyzmq-26.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:471312a7375571857a089342beccc1a63584315188560c7c0da7e0a23afd8a5c"}, - {file = "pyzmq-26.1.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e6cea102ffa16b737d11932c426f1dc14b5938cf7bc12e17269559c458ac334"}, - {file = "pyzmq-26.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec7248673ffc7104b54e4957cee38b2f3075a13442348c8d651777bf41aa45ee"}, - {file = "pyzmq-26.1.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:0614aed6f87d550b5cecb03d795f4ddbb1544b78d02a4bd5eecf644ec98a39f6"}, - {file = "pyzmq-26.1.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:e8746ce968be22a8a1801bf4a23e565f9687088580c3ed07af5846580dd97f76"}, - {file = "pyzmq-26.1.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:7688653574392d2eaeef75ddcd0b2de5b232d8730af29af56c5adf1df9ef8d6f"}, - {file = "pyzmq-26.1.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:8d4dac7d97f15c653a5fedcafa82626bd6cee1450ccdaf84ffed7ea14f2b07a4"}, - {file = "pyzmq-26.1.0-cp313-cp313-win32.whl", hash = "sha256:ccb42ca0a4a46232d716779421bbebbcad23c08d37c980f02cc3a6bd115ad277"}, - {file = "pyzmq-26.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:e1e5d0a25aea8b691a00d6b54b28ac514c8cc0d8646d05f7ca6cb64b97358250"}, - {file = "pyzmq-26.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:fc82269d24860cfa859b676d18850cbb8e312dcd7eada09e7d5b007e2f3d9eb1"}, - {file = "pyzmq-26.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:416ac51cabd54f587995c2b05421324700b22e98d3d0aa2cfaec985524d16f1d"}, - {file = "pyzmq-26.1.0-cp313-cp313t-macosx_10_15_universal2.whl", hash = "sha256:ff832cce719edd11266ca32bc74a626b814fff236824aa1aeaad399b69fe6eae"}, - {file = "pyzmq-26.1.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:393daac1bcf81b2a23e696b7b638eedc965e9e3d2112961a072b6cd8179ad2eb"}, - {file = "pyzmq-26.1.0-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9869fa984c8670c8ab899a719eb7b516860a29bc26300a84d24d8c1b71eae3ec"}, - {file = "pyzmq-26.1.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b3b8e36fd4c32c0825b4461372949ecd1585d326802b1321f8b6dc1d7e9318c"}, - {file = "pyzmq-26.1.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:3ee647d84b83509b7271457bb428cc347037f437ead4b0b6e43b5eba35fec0aa"}, - {file = "pyzmq-26.1.0-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:45cb1a70eb00405ce3893041099655265fabcd9c4e1e50c330026e82257892c1"}, - {file = "pyzmq-26.1.0-cp313-cp313t-musllinux_1_1_i686.whl", hash = "sha256:5cca7b4adb86d7470e0fc96037771981d740f0b4cb99776d5cb59cd0e6684a73"}, - {file = "pyzmq-26.1.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:91d1a20bdaf3b25f3173ff44e54b1cfbc05f94c9e8133314eb2962a89e05d6e3"}, - {file = "pyzmq-26.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c0665d85535192098420428c779361b8823d3d7ec4848c6af3abb93bc5c915bf"}, - {file = "pyzmq-26.1.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:96d7c1d35ee4a495df56c50c83df7af1c9688cce2e9e0edffdbf50889c167595"}, - {file = "pyzmq-26.1.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b281b5ff5fcc9dcbfe941ac5c7fcd4b6c065adad12d850f95c9d6f23c2652384"}, - {file = "pyzmq-26.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5384c527a9a004445c5074f1e20db83086c8ff1682a626676229aafd9cf9f7d1"}, - {file = 
"pyzmq-26.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:754c99a9840839375ee251b38ac5964c0f369306eddb56804a073b6efdc0cd88"}, - {file = "pyzmq-26.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:9bdfcb74b469b592972ed881bad57d22e2c0acc89f5e8c146782d0d90fb9f4bf"}, - {file = "pyzmq-26.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:bd13f0231f4788db619347b971ca5f319c5b7ebee151afc7c14632068c6261d3"}, - {file = "pyzmq-26.1.0-cp37-cp37m-win32.whl", hash = "sha256:c5668dac86a869349828db5fc928ee3f58d450dce2c85607067d581f745e4fb1"}, - {file = "pyzmq-26.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ad875277844cfaeca7fe299ddf8c8d8bfe271c3dc1caf14d454faa5cdbf2fa7a"}, - {file = "pyzmq-26.1.0-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:65c6e03cc0222eaf6aad57ff4ecc0a070451e23232bb48db4322cc45602cede0"}, - {file = "pyzmq-26.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:038ae4ffb63e3991f386e7fda85a9baab7d6617fe85b74a8f9cab190d73adb2b"}, - {file = "pyzmq-26.1.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:bdeb2c61611293f64ac1073f4bf6723b67d291905308a7de9bb2ca87464e3273"}, - {file = "pyzmq-26.1.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:61dfa5ee9d7df297c859ac82b1226d8fefaf9c5113dc25c2c00ecad6feeeb04f"}, - {file = "pyzmq-26.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3292d384537b9918010769b82ab3e79fca8b23d74f56fc69a679106a3e2c2cf"}, - {file = "pyzmq-26.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:f9499c70c19ff0fbe1007043acb5ad15c1dec7d8e84ab429bca8c87138e8f85c"}, - {file = "pyzmq-26.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d3dd5523ed258ad58fed7e364c92a9360d1af8a9371e0822bd0146bdf017ef4c"}, - {file = "pyzmq-26.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:baba2fd199b098c5544ef2536b2499d2e2155392973ad32687024bd8572a7d1c"}, - {file = "pyzmq-26.1.0-cp38-cp38-win32.whl", hash = "sha256:ddbb2b386128d8eca92bd9ca74e80f73fe263bcca7aa419f5b4cbc1661e19741"}, - {file = "pyzmq-26.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:79e45a4096ec8388cdeb04a9fa5e9371583bcb826964d55b8b66cbffe7b33c86"}, - {file = "pyzmq-26.1.0-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:add52c78a12196bc0fda2de087ba6c876ea677cbda2e3eba63546b26e8bf177b"}, - {file = "pyzmq-26.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:98c03bd7f3339ff47de7ea9ac94a2b34580a8d4df69b50128bb6669e1191a895"}, - {file = "pyzmq-26.1.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:dcc37d9d708784726fafc9c5e1232de655a009dbf97946f117aefa38d5985a0f"}, - {file = "pyzmq-26.1.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5a6ed52f0b9bf8dcc64cc82cce0607a3dfed1dbb7e8c6f282adfccc7be9781de"}, - {file = "pyzmq-26.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:451e16ae8bea3d95649317b463c9f95cd9022641ec884e3d63fc67841ae86dfe"}, - {file = "pyzmq-26.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:906e532c814e1d579138177a00ae835cd6becbf104d45ed9093a3aaf658f6a6a"}, - {file = "pyzmq-26.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:05bacc4f94af468cc82808ae3293390278d5f3375bb20fef21e2034bb9a505b6"}, - {file = "pyzmq-26.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:57bb2acba798dc3740e913ffadd56b1fcef96f111e66f09e2a8db3050f1f12c8"}, - {file = "pyzmq-26.1.0-cp39-cp39-win32.whl", hash = "sha256:f774841bb0e8588505002962c02da420bcfb4c5056e87a139c6e45e745c0e2e2"}, - {file = 
"pyzmq-26.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:359c533bedc62c56415a1f5fcfd8279bc93453afdb0803307375ecf81c962402"}, - {file = "pyzmq-26.1.0-cp39-cp39-win_arm64.whl", hash = "sha256:7907419d150b19962138ecec81a17d4892ea440c184949dc29b358bc730caf69"}, - {file = "pyzmq-26.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b24079a14c9596846bf7516fe75d1e2188d4a528364494859106a33d8b48be38"}, - {file = "pyzmq-26.1.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59d0acd2976e1064f1b398a00e2c3e77ed0a157529779e23087d4c2fb8aaa416"}, - {file = "pyzmq-26.1.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:911c43a4117915203c4cc8755e0f888e16c4676a82f61caee2f21b0c00e5b894"}, - {file = "pyzmq-26.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b10163e586cc609f5f85c9b233195554d77b1e9a0801388907441aaeb22841c5"}, - {file = "pyzmq-26.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:28a8b2abb76042f5fd7bd720f7fea48c0fd3e82e9de0a1bf2c0de3812ce44a42"}, - {file = "pyzmq-26.1.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:bef24d3e4ae2c985034439f449e3f9e06bf579974ce0e53d8a507a1577d5b2ab"}, - {file = "pyzmq-26.1.0-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:2cd0f4d314f4a2518e8970b6f299ae18cff7c44d4a1fc06fc713f791c3a9e3ea"}, - {file = "pyzmq-26.1.0-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:fa25a620eed2a419acc2cf10135b995f8f0ce78ad00534d729aa761e4adcef8a"}, - {file = "pyzmq-26.1.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef3b048822dca6d231d8a8ba21069844ae38f5d83889b9b690bf17d2acc7d099"}, - {file = "pyzmq-26.1.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:9a6847c92d9851b59b9f33f968c68e9e441f9a0f8fc972c5580c5cd7cbc6ee24"}, - {file = "pyzmq-26.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c9b9305004d7e4e6a824f4f19b6d8f32b3578aad6f19fc1122aaf320cbe3dc83"}, - {file = "pyzmq-26.1.0-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:63c1d3a65acb2f9c92dce03c4e1758cc552f1ae5c78d79a44e3bb88d2fa71f3a"}, - {file = "pyzmq-26.1.0-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d36b8fffe8b248a1b961c86fbdfa0129dfce878731d169ede7fa2631447331be"}, - {file = "pyzmq-26.1.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67976d12ebfd61a3bc7d77b71a9589b4d61d0422282596cf58c62c3866916544"}, - {file = "pyzmq-26.1.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:998444debc8816b5d8d15f966e42751032d0f4c55300c48cc337f2b3e4f17d03"}, - {file = "pyzmq-26.1.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:e5c88b2f13bcf55fee78ea83567b9fe079ba1a4bef8b35c376043440040f7edb"}, - {file = "pyzmq-26.1.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d906d43e1592be4b25a587b7d96527cb67277542a5611e8ea9e996182fae410"}, - {file = "pyzmq-26.1.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80b0c9942430d731c786545da6be96d824a41a51742e3e374fedd9018ea43106"}, - {file = "pyzmq-26.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:314d11564c00b77f6224d12eb3ddebe926c301e86b648a1835c5b28176c83eab"}, - {file = "pyzmq-26.1.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:093a1a3cae2496233f14b57f4b485da01b4ff764582c854c0f42c6dd2be37f3d"}, - {file = 
"pyzmq-26.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:3c397b1b450f749a7e974d74c06d69bd22dd362142f370ef2bd32a684d6b480c"}, - {file = "pyzmq-26.1.0.tar.gz", hash = "sha256:6c5aeea71f018ebd3b9115c7cb13863dd850e98ca6b9258509de1246461a7e7f"}, + {file = "pyzmq-26.1.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:b1bb952d1e407463c9333ea7e0c0600001e54e08ce836d4f0aff1fb3f902cf63"}, + {file = "pyzmq-26.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:65e2a18e845c6ea7ab849c70db932eaeadee5edede9e379eb21c0a44cf523b2e"}, + {file = "pyzmq-26.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:def7ae3006924b8a0c146a89ab4008310913fa903beedb95e25dea749642528e"}, + {file = "pyzmq-26.1.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a8234571df7816f99dde89c3403cb396d70c6554120b795853a8ea56fcc26cd3"}, + {file = "pyzmq-26.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18da8e84dbc30688fd2baefd41df7190607511f916be34f9a24b0e007551822e"}, + {file = "pyzmq-26.1.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c70dab93d98b2bf3f0ac1265edbf6e7f83acbf71dabcc4611889bb0dea45bed7"}, + {file = "pyzmq-26.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:fcb90592c5d5c562e1b1a1ceccf6f00036d73c51db0271bf4d352b8d6b31d468"}, + {file = "pyzmq-26.1.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cf4be7460a0c1bc71e9b0e64ecdd75a86386ca6afaa36641686f5542d0314e9d"}, + {file = "pyzmq-26.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4cbecda4ddbfc1e309c3be04d333f9be3fc6178b8b6592b309676f929767a15"}, + {file = "pyzmq-26.1.1-cp310-cp310-win32.whl", hash = "sha256:583f73b113b8165713b6ce028d221402b1b69483055b5aa3f991937e34dd1ead"}, + {file = "pyzmq-26.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:5e6f39ecb8eb7bfcb976c49262e8cf83ff76e082b77ca23ba90c9b6691a345be"}, + {file = "pyzmq-26.1.1-cp310-cp310-win_arm64.whl", hash = "sha256:8d042d6446cab3a1388b38596f5acabb9926b0b95c3894c519356b577a549458"}, + {file = "pyzmq-26.1.1-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:362cac2423e36966d336d79d3ec3eafeabc153ee3e7a5cf580d7e74a34b3d912"}, + {file = "pyzmq-26.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0841633446cb1539a832a19bb24c03a20c00887d0cedd1d891b495b07e5c5cb5"}, + {file = "pyzmq-26.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e1fcdc333afbf9918d0a614a6e10858aede7da49a60f6705a77e343fe86a317"}, + {file = "pyzmq-26.1.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cc8d655627d775475eafdcf0e49e74bcc1e5e90afd9ab813b4da98f092ed7b93"}, + {file = "pyzmq-26.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32de51744820857a6f7c3077e620ab3f607d0e4388dfead885d5124ab9bcdc5e"}, + {file = "pyzmq-26.1.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a880240597010914ffb1d6edd04d3deb7ce6a2abf79a0012751438d13630a671"}, + {file = "pyzmq-26.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:26131b1cec02f941ed2d2b4b8cc051662b1c248b044eff5069df1f500bbced56"}, + {file = "pyzmq-26.1.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:ce05841322b58510607f9508a573138d995a46c7928887bc433de9cb760fd2ad"}, + {file = "pyzmq-26.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:32123ff0a6db521aadf2b95201e967a4e0d11fb89f73663a99d2f54881c07214"}, + {file = "pyzmq-26.1.1-cp311-cp311-win32.whl", hash = "sha256:e790602d7ea1d6c7d8713d571226d67de7ffe47b1e22ae2c043ebd537de1bccb"}, + {file = 
"pyzmq-26.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:717960855f2d6fdc2dba9df49dff31c414187bb11c76af36343a57d1f7083d9a"}, + {file = "pyzmq-26.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:08956c26dbcd4fd8835cb777a16e21958ed2412317630e19f0018d49dbeeb470"}, + {file = "pyzmq-26.1.1-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:e80345900ae241c2c51bead7c9fa247bba6d4b2a83423e9791bae8b0a7f12c52"}, + {file = "pyzmq-26.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ec8fe214fcc45dfb0c32e4a7ad1db20244ba2d2fecbf0cbf9d5242d81ca0a375"}, + {file = "pyzmq-26.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf4e283f97688d993cb7a8acbc22889effbbb7cbaa19ee9709751f44be928f5d"}, + {file = "pyzmq-26.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2508bdc8ab246e5ed7c92023d4352aaad63020ca3b098a4e3f1822db202f703d"}, + {file = "pyzmq-26.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:741bdb4d96efe8192616abdc3671931d51a8bcd38c71da2d53fb3127149265d1"}, + {file = "pyzmq-26.1.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:76154943e4c4054b2591792eb3484ef1dd23d59805759f9cebd2f010aa30ee8c"}, + {file = "pyzmq-26.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9498ac427d20d0e0ef0e4bbd6200841e91640dfdf619f544ceec7f464cfb6070"}, + {file = "pyzmq-26.1.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f34453ef3496ca3462f30435bf85f535f9550392987341f9ccc92c102825a79"}, + {file = "pyzmq-26.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:50f0669324e27cc2091ef6ab76ca7112f364b6249691790b4cffce31e73fda28"}, + {file = "pyzmq-26.1.1-cp312-cp312-win32.whl", hash = "sha256:3ee5cbf2625b94de21c68d0cefd35327c8dfdbd6a98fcc41682b4e8bb00d841f"}, + {file = "pyzmq-26.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:75bd448a28b1001b6928679015bc95dd5f172703ed30135bb9e34fc9cda0a3e7"}, + {file = "pyzmq-26.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:4350233569b4bbef88595c5e77ee38995a6f1f1790fae148b578941bfffd1c24"}, + {file = "pyzmq-26.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6c8087a3281c20b1d11042d372ed5a47734af05975d78e4d1d6e7bd1018535f3"}, + {file = "pyzmq-26.1.1-cp313-cp313-macosx_10_15_universal2.whl", hash = "sha256:ebef7d3fe11fe4c688f08bc0211a976c3318c097057f258428200737b9fff4da"}, + {file = "pyzmq-26.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a5342110510045a47de1e87f5f1dcc1d9d90109522316dc9830cfc6157c800f"}, + {file = "pyzmq-26.1.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:af690ea4be6ca92a67c2b44a779a023bf0838e92d48497a2268175dc4a505691"}, + {file = "pyzmq-26.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc994e220c1403ae087d7f0fa45129d583e46668a019e389060da811a5a9320e"}, + {file = "pyzmq-26.1.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:b8e153f5dffb0310af71fc6fc9cd8174f4c8ea312c415adcb815d786fee78179"}, + {file = "pyzmq-26.1.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:0065026e624052a51033857e5cd45a94b52946b44533f965f0bdf182460e965d"}, + {file = "pyzmq-26.1.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:63351392f948b5d50b9f55161994bc4feedbfb3f3cfe393d2f503dea2c3ec445"}, + {file = "pyzmq-26.1.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ffecc43b3c18e36b62fcec995761829b6ac325d8dd74a4f2c5c1653afbb4495a"}, + {file = "pyzmq-26.1.1-cp313-cp313-win32.whl", hash = "sha256:6ff14c2fae6c0c2c1c02590c5c5d75aa1db35b859971b3ca2fcd28f983d9f2b6"}, + 
{file = "pyzmq-26.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:85f2d2ee5ea9a8f1de86a300e1062fbab044f45b5ce34d20580c0198a8196db0"}, + {file = "pyzmq-26.1.1-cp313-cp313-win_arm64.whl", hash = "sha256:cc09b1de8b985ca5a0ca343dd7fb007267c6b329347a74e200f4654268084239"}, + {file = "pyzmq-26.1.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:bc904e86de98f8fc5bd41597da5d61232d2d6d60c4397f26efffabb961b2b245"}, + {file = "pyzmq-26.1.1-cp313-cp313t-macosx_10_15_universal2.whl", hash = "sha256:00f39c367bbd6aa8e4bc36af6510561944c619b58eb36199fa334b594a18f615"}, + {file = "pyzmq-26.1.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de6f384864a959866b782e6a3896538d1424d183f2d3c7ef079f71dcecde7284"}, + {file = "pyzmq-26.1.1-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3abb15df0c763339edb27a644c19381b2425ddd1aea3dbd77c1601a3b31867b8"}, + {file = "pyzmq-26.1.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40908ec2dd3b29bbadc0916a0d3c87f8dbeebbd8fead8e618539f09e0506dec4"}, + {file = "pyzmq-26.1.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:c11a95d3f6fc7e714ccd1066f68f9c1abd764a8b3596158be92f46dd49f41e03"}, + {file = "pyzmq-26.1.1-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:4437af9fee7a58302dbd511cc49f0cc2b35c112a33a1111fb123cf0be45205ca"}, + {file = "pyzmq-26.1.1-cp313-cp313t-musllinux_1_1_i686.whl", hash = "sha256:76390d3d66406cb01b9681c382874400e9dfd77f30ecdea4bd1bf5226dd4aff0"}, + {file = "pyzmq-26.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:4d4c7fe5e50e269f9c63a260638488fec194a73993008618a59b54c47ef6ae72"}, + {file = "pyzmq-26.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:25d128524207f53f7aae7c5abdc2b63f8957a060b00521af5ffcd20986b5d8f4"}, + {file = "pyzmq-26.1.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d74b925d997e4f92b042bdd7085cd0a309ee0fd7cb4dc376059bbff6b32ff34f"}, + {file = "pyzmq-26.1.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:732f957441e5b1c65a7509395e6b6cafee9e12df9aa5f4bf92ed266fe0ba70ee"}, + {file = "pyzmq-26.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0a45102ad7ed9f9ddf2bd699cc5df37742cf7301111cba06001b927efecb120"}, + {file = "pyzmq-26.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9f380d5333fc7cd17423f486125dcc073918676e33db70a6a8172b19fc78d23d"}, + {file = "pyzmq-26.1.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:8eaffcd6bf6a9d00b66a2052a33fa7e6a6575427e9644395f13c3d070f2918dc"}, + {file = "pyzmq-26.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:f1483d4975ae1b387b39bb8e23d1ff32fe5621aa9e4ed3055d05e9c5613fea53"}, + {file = "pyzmq-26.1.1-cp37-cp37m-win32.whl", hash = "sha256:a83653c6bbe5887caea55e49fbd2909c14b73acf43bcc051eb60b2d514bbd46e"}, + {file = "pyzmq-26.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9763a8d3f5f74ef679989b373c37cc22e8d07e56d26439205cb83edb7722357f"}, + {file = "pyzmq-26.1.1-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:2b045647caf620ce0ed6c8fd9fb6a73116f99aceed966b152a5ba1b416d25311"}, + {file = "pyzmq-26.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f66dcb6625c002f209cdc12cae1a1fec926493cd2262efe37dc6b25a30cea863"}, + {file = "pyzmq-26.1.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0cf1d980c969fb9e538f52abd2227f09e015096bc5c3ef7aa26e0d64051c1db8"}, + {file = "pyzmq-26.1.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = 
"sha256:443ebf5e261a95ee9725693f2a5a71401f89b89df0e0ea58844b074067aac2f1"}, + {file = "pyzmq-26.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29de77ba1b1877fe7defc1b9140e65cbd35f72a63bc501e56c2eae55bde5fff4"}, + {file = "pyzmq-26.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f6071ec95af145d7b659dae6786871cd85f0acc599286b6f8ba0c74592d83dd"}, + {file = "pyzmq-26.1.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6f0512fc87629ad968889176bf2165d721cd817401a281504329e2a2ed0ca6a3"}, + {file = "pyzmq-26.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5ccfcf13e80719f6a2d9c0a021d9e47d4550907a29253554be2c09582f6d7963"}, + {file = "pyzmq-26.1.1-cp38-cp38-win32.whl", hash = "sha256:809673947e95752e407aaaaf03f205ee86ebfff9ca51db6d4003dfd87b8428d1"}, + {file = "pyzmq-26.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:62b5180e23e6f581600459cd983473cd723fdc64350f606d21407c99832aaf5f"}, + {file = "pyzmq-26.1.1-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:fe73d7c89d6f803bed122135ff5783364e8cdb479cf6fe2d764a44b6349e7e0f"}, + {file = "pyzmq-26.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:db1b7e2b50ef21f398036786da4c153db63203a402396d9f21e08ea61f3f8dba"}, + {file = "pyzmq-26.1.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:7c506a51cb01bb997a3f6440db0d121e5e7a32396e9948b1fdb6a7bfa67243f4"}, + {file = "pyzmq-26.1.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:92eca4f80e8a748d880e55d3cf57ef487692e439f12d5c5a2e1cce84aaa7f6cb"}, + {file = "pyzmq-26.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14bdbae02f72f4716b0ffe7500e9da303d719ddde1f3dcfb4c4f6cc1cf73bb02"}, + {file = "pyzmq-26.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e03be7ed17836c9434cce0668ac1e2cc9143d7169f90f46a0167f6155e176e32"}, + {file = "pyzmq-26.1.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc5df31e36e4fddd4c8b5c42daee8d54d7b529e898ac984be97bf5517de166a7"}, + {file = "pyzmq-26.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f218179c90a12d660906e04b25a340dd63e9743000ba16232ddaf46888f269da"}, + {file = "pyzmq-26.1.1-cp39-cp39-win32.whl", hash = "sha256:7dfabc180a4da422a4b349c63077347392463a75fa07aa3be96712ed6d42c547"}, + {file = "pyzmq-26.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:c5248e6e0fcbbbc912982e99cdd51c342601f495b0fa5bd667f3bdbdbf3e170f"}, + {file = "pyzmq-26.1.1-cp39-cp39-win_arm64.whl", hash = "sha256:2ae7aa1408778dc74582a1226052b930f9083b54b64d7e6ef6ec0466cfdcdec2"}, + {file = "pyzmq-26.1.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:be3fc2b11c0c384949cf1f01f9a48555039408b0f3e877863b1754225635953e"}, + {file = "pyzmq-26.1.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48dee75c2a9fa4f4a583d4028d564a0453447ee1277a29b07acc3743c092e259"}, + {file = "pyzmq-26.1.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:23f2fe4fb567e8098ebaa7204819658195b10ddd86958a97a6058eed2901eed3"}, + {file = "pyzmq-26.1.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:472cacd16f627c06d3c8b2d374345ab74446bae913584a6245e2aa935336d929"}, + {file = "pyzmq-26.1.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:8285b25aa20fcc46f1ca4afbc39fd3d5f2fe4c4bbf7f2c7f907a214e87a70024"}, + {file = "pyzmq-26.1.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2067e63fd9d5c13cfe12624dab0366053e523b37a7a01678ce4321f839398939"}, + {file = 
"pyzmq-26.1.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:cc109be2ee3638035d276e18eaf66a1e1f44201c0c4bea4ee0c692766bbd3570"}, + {file = "pyzmq-26.1.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d0da97e65ee73261dba70469cc8f63d8da3a8a825337a2e3d246b9e95141cdd0"}, + {file = "pyzmq-26.1.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa79c528706561306938b275f89bb2c6985ce08469c27e5de05bc680df5e826f"}, + {file = "pyzmq-26.1.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:3ddbd851a3a2651fdc5065a2804d50cf2f4b13b1bcd66de8e9e855d0217d4fcd"}, + {file = "pyzmq-26.1.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d3df226ab7464684ae6706e20a5cbab717c3735a7e409b3fa598b754d49f1946"}, + {file = "pyzmq-26.1.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:abad7b897e960d577eb4a0f3f789c1780bc3ffe2e7c27cf317e7c90ad26acf12"}, + {file = "pyzmq-26.1.1-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c513d829a548c2d5c88983167be2b3aa537f6d1191edcdc6fcd8999e18bdd994"}, + {file = "pyzmq-26.1.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70af4c9c991714ef1c65957605a8de42ef0d0620dd5f125953c8e682281bdb80"}, + {file = "pyzmq-26.1.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:8d4234f335b0d0842f7d661d8cd50cbad0729be58f1c4deb85cd96b38fe95025"}, + {file = "pyzmq-26.1.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:2c0fdb7b758e0e1605157e480b00b3a599073068a37091a1c75ec65bf7498645"}, + {file = "pyzmq-26.1.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc657577f057d60dd3642c9f95f28b432889b73143140061f7c1331d02f03df6"}, + {file = "pyzmq-26.1.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3e3b66fe6131b4f33d239f7d4c3bfb2f8532d8644bae3b3da4f3987073edac55"}, + {file = "pyzmq-26.1.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59b57e912feef6951aec8bb03fe0faa5ad5f36962883c72a30a9c965e6d988fd"}, + {file = "pyzmq-26.1.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:146956aec7d947c5afc5e7da0841423d7a53f84fd160fff25e682361dcfb32cb"}, + {file = "pyzmq-26.1.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9521b874fd489495865172f344e46e0159095d1f161858e3fc6e28e43ca15160"}, + {file = "pyzmq-26.1.1.tar.gz", hash = "sha256:a7db05d8b7cd1a8c6610e9e9aa55d525baae7a44a43e18bc3260eb3f92de96c6"}, ] [package.dependencies] @@ -4535,141 +4775,141 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "rpds-py" -version = "0.19.1" +version = "0.20.0" description = "Python bindings to Rust's persistent data structures (rpds)" optional = false python-versions = ">=3.8" files = [ - {file = "rpds_py-0.19.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:aaf71f95b21f9dc708123335df22e5a2fef6307e3e6f9ed773b2e0938cc4d491"}, - {file = "rpds_py-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ca0dda0c5715efe2ab35bb83f813f681ebcd2840d8b1b92bfc6fe3ab382fae4a"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81db2e7282cc0487f500d4db203edc57da81acde9e35f061d69ed983228ffe3b"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1a8dfa125b60ec00c7c9baef945bb04abf8ac772d8ebefd79dae2a5f316d7850"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:271accf41b02687cef26367c775ab220372ee0f4925591c6796e7c148c50cab5"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f9bc4161bd3b970cd6a6fcda70583ad4afd10f2750609fb1f3ca9505050d4ef3"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0cf2a0dbb5987da4bd92a7ca727eadb225581dd9681365beba9accbe5308f7d"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b5e28e56143750808c1c79c70a16519e9bc0a68b623197b96292b21b62d6055c"}, - {file = "rpds_py-0.19.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c7af6f7b80f687b33a4cdb0a785a5d4de1fb027a44c9a049d8eb67d5bfe8a687"}, - {file = "rpds_py-0.19.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e429fc517a1c5e2a70d576077231538a98d59a45dfc552d1ac45a132844e6dfb"}, - {file = "rpds_py-0.19.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d2dbd8f4990d4788cb122f63bf000357533f34860d269c1a8e90ae362090ff3a"}, - {file = "rpds_py-0.19.1-cp310-none-win32.whl", hash = "sha256:e0f9d268b19e8f61bf42a1da48276bcd05f7ab5560311f541d22557f8227b866"}, - {file = "rpds_py-0.19.1-cp310-none-win_amd64.whl", hash = "sha256:df7c841813f6265e636fe548a49664c77af31ddfa0085515326342a751a6ba51"}, - {file = "rpds_py-0.19.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:902cf4739458852fe917104365ec0efbea7d29a15e4276c96a8d33e6ed8ec137"}, - {file = "rpds_py-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f3d73022990ab0c8b172cce57c69fd9a89c24fd473a5e79cbce92df87e3d9c48"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3837c63dd6918a24de6c526277910e3766d8c2b1627c500b155f3eecad8fad65"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cdb7eb3cf3deb3dd9e7b8749323b5d970052711f9e1e9f36364163627f96da58"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26ab43b6d65d25b1a333c8d1b1c2f8399385ff683a35ab5e274ba7b8bb7dc61c"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:75130df05aae7a7ac171b3b5b24714cffeabd054ad2ebc18870b3aa4526eba23"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c34f751bf67cab69638564eee34023909380ba3e0d8ee7f6fe473079bf93f09b"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f2671cb47e50a97f419a02cd1e0c339b31de017b033186358db92f4d8e2e17d8"}, - {file = "rpds_py-0.19.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c73254c256081704dba0a333457e2fb815364018788f9b501efe7c5e0ada401"}, - {file = "rpds_py-0.19.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4383beb4a29935b8fa28aca8fa84c956bf545cb0c46307b091b8d312a9150e6a"}, - {file = "rpds_py-0.19.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:dbceedcf4a9329cc665452db1aaf0845b85c666e4885b92ee0cddb1dbf7e052a"}, - {file = "rpds_py-0.19.1-cp311-none-win32.whl", hash = "sha256:f0a6d4a93d2a05daec7cb885157c97bbb0be4da739d6f9dfb02e101eb40921cd"}, - {file = "rpds_py-0.19.1-cp311-none-win_amd64.whl", hash = "sha256:c149a652aeac4902ecff2dd93c3b2681c608bd5208c793c4a99404b3e1afc87c"}, - {file = "rpds_py-0.19.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:56313be667a837ff1ea3508cebb1ef6681d418fa2913a0635386cf29cff35165"}, - {file = "rpds_py-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:6d1d7539043b2b31307f2c6c72957a97c839a88b2629a348ebabe5aa8b626d6b"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e1dc59a5e7bc7f44bd0c048681f5e05356e479c50be4f2c1a7089103f1621d5"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8f78398e67a7227aefa95f876481485403eb974b29e9dc38b307bb6eb2315ea"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ef07a0a1d254eeb16455d839cef6e8c2ed127f47f014bbda64a58b5482b6c836"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8124101e92c56827bebef084ff106e8ea11c743256149a95b9fd860d3a4f331f"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08ce9c95a0b093b7aec75676b356a27879901488abc27e9d029273d280438505"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0b02dd77a2de6e49078c8937aadabe933ceac04b41c5dde5eca13a69f3cf144e"}, - {file = "rpds_py-0.19.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4dd02e29c8cbed21a1875330b07246b71121a1c08e29f0ee3db5b4cfe16980c4"}, - {file = "rpds_py-0.19.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9c7042488165f7251dc7894cd533a875d2875af6d3b0e09eda9c4b334627ad1c"}, - {file = "rpds_py-0.19.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f809a17cc78bd331e137caa25262b507225854073fd319e987bd216bed911b7c"}, - {file = "rpds_py-0.19.1-cp312-none-win32.whl", hash = "sha256:3ddab996807c6b4227967fe1587febade4e48ac47bb0e2d3e7858bc621b1cace"}, - {file = "rpds_py-0.19.1-cp312-none-win_amd64.whl", hash = "sha256:32e0db3d6e4f45601b58e4ac75c6f24afbf99818c647cc2066f3e4b192dabb1f"}, - {file = "rpds_py-0.19.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:747251e428406b05fc86fee3904ee19550c4d2d19258cef274e2151f31ae9d38"}, - {file = "rpds_py-0.19.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:dc733d35f861f8d78abfaf54035461e10423422999b360966bf1c443cbc42705"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbda75f245caecff8faa7e32ee94dfaa8312a3367397975527f29654cd17a6ed"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bd04d8cab16cab5b0a9ffc7d10f0779cf1120ab16c3925404428f74a0a43205a"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2d66eb41ffca6cc3c91d8387509d27ba73ad28371ef90255c50cb51f8953301"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fdf4890cda3b59170009d012fca3294c00140e7f2abe1910e6a730809d0f3f9b"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1fa67ef839bad3815124f5f57e48cd50ff392f4911a9f3cf449d66fa3df62a5"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b82c9514c6d74b89a370c4060bdb80d2299bc6857e462e4a215b4ef7aa7b090e"}, - {file = "rpds_py-0.19.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c7b07959866a6afb019abb9564d8a55046feb7a84506c74a6f197cbcdf8a208e"}, - {file = "rpds_py-0.19.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4f580ae79d0b861dfd912494ab9d477bea535bfb4756a2269130b6607a21802e"}, - {file = "rpds_py-0.19.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c6d20c8896c00775e6f62d8373aba32956aa0b850d02b5ec493f486c88e12859"}, - 
{file = "rpds_py-0.19.1-cp313-none-win32.whl", hash = "sha256:afedc35fe4b9e30ab240b208bb9dc8938cb4afe9187589e8d8d085e1aacb8309"}, - {file = "rpds_py-0.19.1-cp313-none-win_amd64.whl", hash = "sha256:1d4af2eb520d759f48f1073ad3caef997d1bfd910dc34e41261a595d3f038a94"}, - {file = "rpds_py-0.19.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:34bca66e2e3eabc8a19e9afe0d3e77789733c702c7c43cd008e953d5d1463fde"}, - {file = "rpds_py-0.19.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:24f8ae92c7fae7c28d0fae9b52829235df83f34847aa8160a47eb229d9666c7b"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71157f9db7f6bc6599a852852f3389343bea34315b4e6f109e5cbc97c1fb2963"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1d494887d40dc4dd0d5a71e9d07324e5c09c4383d93942d391727e7a40ff810b"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7b3661e6d4ba63a094138032c1356d557de5b3ea6fd3cca62a195f623e381c76"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97fbb77eaeb97591efdc654b8b5f3ccc066406ccfb3175b41382f221ecc216e8"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4cc4bc73e53af8e7a42c8fd7923bbe35babacfa7394ae9240b3430b5dcf16b2a"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:35af5e4d5448fa179fd7fff0bba0fba51f876cd55212f96c8bbcecc5c684ae5c"}, - {file = "rpds_py-0.19.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3511f6baf8438326e351097cecd137eb45c5f019944fe0fd0ae2fea2fd26be39"}, - {file = "rpds_py-0.19.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:57863d16187995c10fe9cf911b897ed443ac68189179541734502353af33e693"}, - {file = "rpds_py-0.19.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:9e318e6786b1e750a62f90c6f7fa8b542102bdcf97c7c4de2a48b50b61bd36ec"}, - {file = "rpds_py-0.19.1-cp38-none-win32.whl", hash = "sha256:53dbc35808c6faa2ce3e48571f8f74ef70802218554884787b86a30947842a14"}, - {file = "rpds_py-0.19.1-cp38-none-win_amd64.whl", hash = "sha256:8df1c283e57c9cb4d271fdc1875f4a58a143a2d1698eb0d6b7c0d7d5f49c53a1"}, - {file = "rpds_py-0.19.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e76c902d229a3aa9d5ceb813e1cbcc69bf5bda44c80d574ff1ac1fa3136dea71"}, - {file = "rpds_py-0.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:de1f7cd5b6b351e1afd7568bdab94934d656abe273d66cda0ceea43bbc02a0c2"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24fc5a84777cb61692d17988989690d6f34f7f95968ac81398d67c0d0994a897"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:74129d5ffc4cde992d89d345f7f7d6758320e5d44a369d74d83493429dad2de5"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5e360188b72f8080fefa3adfdcf3618604cc8173651c9754f189fece068d2a45"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13e6d4840897d4e4e6b2aa1443e3a8eca92b0402182aafc5f4ca1f5e24f9270a"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f09529d2332264a902688031a83c19de8fda5eb5881e44233286b9c9ec91856d"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0d4b52811dcbc1aba08fd88d475f75b4f6db0984ba12275d9bed1a04b2cae9b5"}, - 
{file = "rpds_py-0.19.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dd635c2c4043222d80d80ca1ac4530a633102a9f2ad12252183bcf338c1b9474"}, - {file = "rpds_py-0.19.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f35b34a5184d5e0cc360b61664c1c06e866aab077b5a7c538a3e20c8fcdbf90b"}, - {file = "rpds_py-0.19.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d4ec0046facab83012d821b33cead742a35b54575c4edfb7ed7445f63441835f"}, - {file = "rpds_py-0.19.1-cp39-none-win32.whl", hash = "sha256:f5b8353ea1a4d7dfb59a7f45c04df66ecfd363bb5b35f33b11ea579111d4655f"}, - {file = "rpds_py-0.19.1-cp39-none-win_amd64.whl", hash = "sha256:1fb93d3486f793d54a094e2bfd9cd97031f63fcb5bc18faeb3dd4b49a1c06523"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7d5c7e32f3ee42f77d8ff1a10384b5cdcc2d37035e2e3320ded909aa192d32c3"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:89cc8921a4a5028d6dd388c399fcd2eef232e7040345af3d5b16c04b91cf3c7e"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bca34e913d27401bda2a6f390d0614049f5a95b3b11cd8eff80fe4ec340a1208"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5953391af1405f968eb5701ebbb577ebc5ced8d0041406f9052638bafe52209d"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:840e18c38098221ea6201f091fc5d4de6128961d2930fbbc96806fb43f69aec1"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6d8b735c4d162dc7d86a9cf3d717f14b6c73637a1f9cd57fe7e61002d9cb1972"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce757c7c90d35719b38fa3d4ca55654a76a40716ee299b0865f2de21c146801c"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a9421b23c85f361a133aa7c5e8ec757668f70343f4ed8fdb5a4a14abd5437244"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:3b823be829407393d84ee56dc849dbe3b31b6a326f388e171555b262e8456cc1"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:5e58b61dcbb483a442c6239c3836696b79f2cd8e7eec11e12155d3f6f2d886d1"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39d67896f7235b2c886fb1ee77b1491b77049dcef6fbf0f401e7b4cbed86bbd4"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:8b32cd4ab6db50c875001ba4f5a6b30c0f42151aa1fbf9c2e7e3674893fb1dc4"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:1c32e41de995f39b6b315d66c27dea3ef7f7c937c06caab4c6a79a5e09e2c415"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1a129c02b42d46758c87faeea21a9f574e1c858b9f358b6dd0bbd71d17713175"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:346557f5b1d8fd9966059b7a748fd79ac59f5752cd0e9498d6a40e3ac1c1875f"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:31e450840f2f27699d014cfc8865cc747184286b26d945bcea6042bb6aa4d26e"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:01227f8b3e6c8961490d869aa65c99653df80d2f0a7fde8c64ebddab2b9b02fd"}, - {file = 
"rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:69084fd29bfeff14816666c93a466e85414fe6b7d236cfc108a9c11afa6f7301"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4d2b88efe65544a7d5121b0c3b003ebba92bfede2ea3577ce548b69c5235185"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ea961a674172ed2235d990d7edf85d15d8dfa23ab8575e48306371c070cda67"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:5beffdbe766cfe4fb04f30644d822a1080b5359df7db3a63d30fa928375b2720"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:720f3108fb1bfa32e51db58b832898372eb5891e8472a8093008010911e324c5"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:c2087dbb76a87ec2c619253e021e4fb20d1a72580feeaa6892b0b3d955175a71"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2ddd50f18ebc05ec29a0d9271e9dbe93997536da3546677f8ca00b76d477680c"}, - {file = "rpds_py-0.19.1.tar.gz", hash = "sha256:31dd5794837f00b46f4096aa8ccaa5972f73a938982e32ed817bb520c465e520"}, + {file = "rpds_py-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3ad0fda1635f8439cde85c700f964b23ed5fc2d28016b32b9ee5fe30da5c84e2"}, + {file = "rpds_py-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9bb4a0d90fdb03437c109a17eade42dfbf6190408f29b2744114d11586611d6f"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6377e647bbfd0a0b159fe557f2c6c602c159fc752fa316572f012fc0bf67150"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb851b7df9dda52dc1415ebee12362047ce771fc36914586b2e9fcbd7d293b3e"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e0f80b739e5a8f54837be5d5c924483996b603d5502bfff79bf33da06164ee2"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a8c94dad2e45324fc74dce25e1645d4d14df9a4e54a30fa0ae8bad9a63928e3"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e604fe73ba048c06085beaf51147eaec7df856824bfe7b98657cf436623daf"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:df3de6b7726b52966edf29663e57306b23ef775faf0ac01a3e9f4012a24a4140"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf258ede5bc22a45c8e726b29835b9303c285ab46fc7c3a4cc770736b5304c9f"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:55fea87029cded5df854ca7e192ec7bdb7ecd1d9a3f63d5c4eb09148acf4a7ce"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ae94bd0b2f02c28e199e9bc51485d0c5601f58780636185660f86bf80c89af94"}, + {file = "rpds_py-0.20.0-cp310-none-win32.whl", hash = "sha256:28527c685f237c05445efec62426d285e47a58fb05ba0090a4340b73ecda6dee"}, + {file = "rpds_py-0.20.0-cp310-none-win_amd64.whl", hash = "sha256:238a2d5b1cad28cdc6ed15faf93a998336eb041c4e440dd7f902528b8891b399"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac2f4f7a98934c2ed6505aead07b979e6f999389f16b714448fb39bbaa86a489"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:220002c1b846db9afd83371d08d239fdc865e8f8c5795bbaec20916a76db3318"}, + {file = 
"rpds_py-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d7919548df3f25374a1f5d01fbcd38dacab338ef5f33e044744b5c36729c8db"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:758406267907b3781beee0f0edfe4a179fbd97c0be2e9b1154d7f0a1279cf8e5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d61339e9f84a3f0767b1995adfb171a0d00a1185192718a17af6e124728e0f5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1259c7b3705ac0a0bd38197565a5d603218591d3f6cee6e614e380b6ba61c6f6"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1dc0f53856b9cc9a0ccca0a7cc61d3d20a7088201c0937f3f4048c1718a209"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7e60cb630f674a31f0368ed32b2a6b4331b8350d67de53c0359992444b116dd3"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbe982f38565bb50cb7fb061ebf762c2f254ca3d8c20d4006878766e84266272"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:514b3293b64187172bc77c8fb0cdae26981618021053b30d8371c3a902d4d5ad"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d0a26ffe9d4dd35e4dfdd1e71f46401cff0181c75ac174711ccff0459135fa58"}, + {file = "rpds_py-0.20.0-cp311-none-win32.whl", hash = "sha256:89c19a494bf3ad08c1da49445cc5d13d8fefc265f48ee7e7556839acdacf69d0"}, + {file = "rpds_py-0.20.0-cp311-none-win_amd64.whl", hash = "sha256:c638144ce971df84650d3ed0096e2ae7af8e62ecbbb7b201c8935c370df00a2c"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a84ab91cbe7aab97f7446652d0ed37d35b68a465aeef8fc41932a9d7eee2c1a6"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:56e27147a5a4c2c21633ff8475d185734c0e4befd1c989b5b95a5d0db699b21b"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2580b0c34583b85efec8c5c5ec9edf2dfe817330cc882ee972ae650e7b5ef739"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b80d4a7900cf6b66bb9cee5c352b2d708e29e5a37fe9bf784fa97fc11504bf6c"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50eccbf054e62a7b2209b28dc7a22d6254860209d6753e6b78cfaeb0075d7bee"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:49a8063ea4296b3a7e81a5dfb8f7b2d73f0b1c20c2af401fb0cdf22e14711a96"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea438162a9fcbee3ecf36c23e6c68237479f89f962f82dae83dc15feeceb37e4"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18d7585c463087bddcfa74c2ba267339f14f2515158ac4db30b1f9cbdb62c8ef"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4c7d1a051eeb39f5c9547e82ea27cbcc28338482242e3e0b7768033cb083821"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e4df1e3b3bec320790f699890d41c59d250f6beda159ea3c44c3f5bac1976940"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2cf126d33a91ee6eedc7f3197b53e87a2acdac63602c0f03a02dd69e4b138174"}, + {file = "rpds_py-0.20.0-cp312-none-win32.whl", hash = 
"sha256:8bc7690f7caee50b04a79bf017a8d020c1f48c2a1077ffe172abec59870f1139"}, + {file = "rpds_py-0.20.0-cp312-none-win_amd64.whl", hash = "sha256:0e13e6952ef264c40587d510ad676a988df19adea20444c2b295e536457bc585"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:aa9a0521aeca7d4941499a73ad7d4f8ffa3d1affc50b9ea11d992cd7eff18a29"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a1f1d51eccb7e6c32ae89243cb352389228ea62f89cd80823ea7dd1b98e0b91"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a86a9b96070674fc88b6f9f71a97d2c1d3e5165574615d1f9168ecba4cecb24"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c8ef2ebf76df43f5750b46851ed1cdf8f109d7787ca40035fe19fbdc1acc5a7"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b25f024b421d5859d156750ea9a65651793d51b76a2e9238c05c9d5f203a9"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57eb94a8c16ab08fef6404301c38318e2c5a32216bf5de453e2714c964c125c8"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1940dae14e715e2e02dfd5b0f64a52e8374a517a1e531ad9412319dc3ac7879"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d20277fd62e1b992a50c43f13fbe13277a31f8c9f70d59759c88f644d66c619f"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:06db23d43f26478303e954c34c75182356ca9aa7797d22c5345b16871ab9c45c"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b2a5db5397d82fa847e4c624b0c98fe59d2d9b7cf0ce6de09e4d2e80f8f5b3f2"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a35df9f5548fd79cb2f52d27182108c3e6641a4feb0f39067911bf2adaa3e57"}, + {file = "rpds_py-0.20.0-cp313-none-win32.whl", hash = "sha256:fd2d84f40633bc475ef2d5490b9c19543fbf18596dcb1b291e3a12ea5d722f7a"}, + {file = "rpds_py-0.20.0-cp313-none-win_amd64.whl", hash = "sha256:9bc2d153989e3216b0559251b0c260cfd168ec78b1fac33dd485750a228db5a2"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f2fbf7db2012d4876fb0d66b5b9ba6591197b0f165db8d99371d976546472a24"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e5f3cd7397c8f86c8cc72d5a791071431c108edd79872cdd96e00abd8497d29"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce9845054c13696f7af7f2b353e6b4f676dab1b4b215d7fe5e05c6f8bb06f965"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c3e130fd0ec56cb76eb49ef52faead8ff09d13f4527e9b0c400307ff72b408e1"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b16aa0107ecb512b568244ef461f27697164d9a68d8b35090e9b0c1c8b27752"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7f429242aae2947246587d2964fad750b79e8c233a2367f71b554e9447949c"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af0fc424a5842a11e28956e69395fbbeab2c97c42253169d87e90aac2886d751"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8c00a3b1e70c1d3891f0db1b05292747f0dbcfb49c43f9244d04c70fbc40eb8"}, + {file = 
"rpds_py-0.20.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:40ce74fc86ee4645d0a225498d091d8bc61f39b709ebef8204cb8b5a464d3c0e"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4fe84294c7019456e56d93e8ababdad5a329cd25975be749c3f5f558abb48253"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:338ca4539aad4ce70a656e5187a3a31c5204f261aef9f6ab50e50bcdffaf050a"}, + {file = "rpds_py-0.20.0-cp38-none-win32.whl", hash = "sha256:54b43a2b07db18314669092bb2de584524d1ef414588780261e31e85846c26a5"}, + {file = "rpds_py-0.20.0-cp38-none-win_amd64.whl", hash = "sha256:a1862d2d7ce1674cffa6d186d53ca95c6e17ed2b06b3f4c476173565c862d232"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3fde368e9140312b6e8b6c09fb9f8c8c2f00999d1823403ae90cc00480221b22"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9824fb430c9cf9af743cf7aaf6707bf14323fb51ee74425c380f4c846ea70789"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11ef6ce74616342888b69878d45e9f779b95d4bd48b382a229fe624a409b72c5"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c52d3f2f82b763a24ef52f5d24358553e8403ce05f893b5347098014f2d9eff2"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d35cef91e59ebbeaa45214861874bc6f19eb35de96db73e467a8358d701a96c"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d72278a30111e5b5525c1dd96120d9e958464316f55adb030433ea905866f4de"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c29cbbba378759ac5786730d1c3cb4ec6f8ababf5c42a9ce303dc4b3d08cda"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6632f2d04f15d1bd6fe0eedd3b86d9061b836ddca4c03d5cf5c7e9e6b7c14580"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d0b67d87bb45ed1cd020e8fbf2307d449b68abc45402fe1a4ac9e46c3c8b192b"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ec31a99ca63bf3cd7f1a5ac9fe95c5e2d060d3c768a09bc1d16e235840861420"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22e6c9976e38f4d8c4a63bd8a8edac5307dffd3ee7e6026d97f3cc3a2dc02a0b"}, + {file = "rpds_py-0.20.0-cp39-none-win32.whl", hash = "sha256:569b3ea770c2717b730b61998b6c54996adee3cef69fc28d444f3e7920313cf7"}, + {file = "rpds_py-0.20.0-cp39-none-win_amd64.whl", hash = "sha256:e6900ecdd50ce0facf703f7a00df12374b74bbc8ad9fe0f6559947fb20f82364"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:617c7357272c67696fd052811e352ac54ed1d9b49ab370261a80d3b6ce385045"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9426133526f69fcaba6e42146b4e12d6bc6c839b8b555097020e2b78ce908dcc"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deb62214c42a261cb3eb04d474f7155279c1a8a8c30ac89b7dcb1721d92c3c02"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcaeb7b57f1a1e071ebd748984359fef83ecb026325b9d4ca847c95bc7311c92"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d454b8749b4bd70dd0a79f428731ee263fa6995f83ccb8bada706e8d1d3ff89d"}, + {file = 
"rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d807dc2051abe041b6649681dce568f8e10668e3c1c6543ebae58f2d7e617855"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3c20f0ddeb6e29126d45f89206b8291352b8c5b44384e78a6499d68b52ae511"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b7f19250ceef892adf27f0399b9e5afad019288e9be756d6919cb58892129f51"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4f1ed4749a08379555cebf4650453f14452eaa9c43d0a95c49db50c18b7da075"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dcedf0b42bcb4cfff4101d7771a10532415a6106062f005ab97d1d0ab5681c60"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39ed0d010457a78f54090fafb5d108501b5aa5604cc22408fc1c0c77eac14344"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f918a1a130a6dfe1d7fe0f105064141342e7dd1611f2e6a21cd2f5c8cb1cfb3e"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f60012a73aa396be721558caa3a6fd49b3dd0033d1675c6d59c4502e870fcf0c"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d2b1ad682a3dfda2a4e8ad8572f3100f95fad98cb99faf37ff0ddfe9cbf9d03"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:614fdafe9f5f19c63ea02817fa4861c606a59a604a77c8cdef5aa01d28b97921"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa518bcd7600c584bf42e6617ee8132869e877db2f76bcdc281ec6a4113a53ab"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0475242f447cc6cb8a9dd486d68b2ef7fbee84427124c232bff5f63b1fe11e5"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f90a4cd061914a60bd51c68bcb4357086991bd0bb93d8aa66a6da7701370708f"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:def7400461c3a3f26e49078302e1c1b38f6752342c77e3cf72ce91ca69fb1bc1"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:65794e4048ee837494aea3c21a28ad5fc080994dfba5b036cf84de37f7ad5074"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:faefcc78f53a88f3076b7f8be0a8f8d35133a3ecf7f3770895c25f8813460f08"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:5b4f105deeffa28bbcdff6c49b34e74903139afa690e35d2d9e3c2c2fba18cec"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdfc3a892927458d98f3d55428ae46b921d1f7543b89382fdb483f5640daaec8"}, + {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, ] [[package]] name = "ruff" -version = "0.5.6" +version = "0.5.7" description = "An extremely fast Python linter and code formatter, written in Rust." 
optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.5.6-py3-none-linux_armv6l.whl", hash = "sha256:a0ef5930799a05522985b9cec8290b185952f3fcd86c1772c3bdbd732667fdcd"}, - {file = "ruff-0.5.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b652dc14f6ef5d1552821e006f747802cc32d98d5509349e168f6bf0ee9f8f42"}, - {file = "ruff-0.5.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:80521b88d26a45e871f31e4b88938fd87db7011bb961d8afd2664982dfc3641a"}, - {file = "ruff-0.5.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9bc8f328a9f1309ae80e4d392836e7dbc77303b38ed4a7112699e63d3b066ab"}, - {file = "ruff-0.5.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4d394940f61f7720ad371ddedf14722ee1d6250fd8d020f5ea5a86e7be217daf"}, - {file = "ruff-0.5.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:111a99cdb02f69ddb2571e2756e017a1496c2c3a2aeefe7b988ddab38b416d36"}, - {file = "ruff-0.5.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e395daba77a79f6dc0d07311f94cc0560375ca20c06f354c7c99af3bf4560c5d"}, - {file = "ruff-0.5.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c476acb43c3c51e3c614a2e878ee1589655fa02dab19fe2db0423a06d6a5b1b6"}, - {file = "ruff-0.5.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e2ff8003f5252fd68425fd53d27c1f08b201d7ed714bb31a55c9ac1d4c13e2eb"}, - {file = "ruff-0.5.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c94e084ba3eaa80c2172918c2ca2eb2230c3f15925f4ed8b6297260c6ef179ad"}, - {file = "ruff-0.5.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1f77c1c3aa0669fb230b06fb24ffa3e879391a3ba3f15e3d633a752da5a3e670"}, - {file = "ruff-0.5.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:f908148c93c02873210a52cad75a6eda856b2cbb72250370ce3afef6fb99b1ed"}, - {file = "ruff-0.5.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:563a7ae61ad284187d3071d9041c08019975693ff655438d8d4be26e492760bd"}, - {file = "ruff-0.5.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:94fe60869bfbf0521e04fd62b74cbca21cbc5beb67cbb75ab33fe8c174f54414"}, - {file = "ruff-0.5.6-py3-none-win32.whl", hash = "sha256:e6a584c1de6f8591c2570e171cc7ce482bb983d49c70ddf014393cd39e9dfaed"}, - {file = "ruff-0.5.6-py3-none-win_amd64.whl", hash = "sha256:d7fe7dccb1a89dc66785d7aa0ac283b2269712d8ed19c63af908fdccca5ccc1a"}, - {file = "ruff-0.5.6-py3-none-win_arm64.whl", hash = "sha256:57c6c0dd997b31b536bff49b9eee5ed3194d60605a4427f735eeb1f9c1b8d264"}, - {file = "ruff-0.5.6.tar.gz", hash = "sha256:07c9e3c2a8e1fe377dd460371c3462671a728c981c3205a5217291422209f642"}, + {file = "ruff-0.5.7-py3-none-linux_armv6l.whl", hash = "sha256:548992d342fc404ee2e15a242cdbea4f8e39a52f2e7752d0e4cbe88d2d2f416a"}, + {file = "ruff-0.5.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00cc8872331055ee017c4f1071a8a31ca0809ccc0657da1d154a1d2abac5c0be"}, + {file = "ruff-0.5.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf3d86a1fdac1aec8a3417a63587d93f906c678bb9ed0b796da7b59c1114a1e"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a01c34400097b06cf8a6e61b35d6d456d5bd1ae6961542de18ec81eaf33b4cb8"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcc8054f1a717e2213500edaddcf1dbb0abad40d98e1bd9d0ad364f75c763eea"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:7f70284e73f36558ef51602254451e50dd6cc479f8b6f8413a95fcb5db4a55fc"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:a78ad870ae3c460394fc95437d43deb5c04b5c29297815a2a1de028903f19692"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ccd078c66a8e419475174bfe60a69adb36ce04f8d4e91b006f1329d5cd44bcf"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e31c9bad4ebf8fdb77b59cae75814440731060a09a0e0077d559a556453acbb"}, + {file = "ruff-0.5.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d796327eed8e168164346b769dd9a27a70e0298d667b4ecee6877ce8095ec8e"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a09ea2c3f7778cc635e7f6edf57d566a8ee8f485f3c4454db7771efb692c499"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a36d8dcf55b3a3bc353270d544fb170d75d2dff41eba5df57b4e0b67a95bb64e"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9369c218f789eefbd1b8d82a8cf25017b523ac47d96b2f531eba73770971c9e5"}, + {file = "ruff-0.5.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b88ca3db7eb377eb24fb7c82840546fb7acef75af4a74bd36e9ceb37a890257e"}, + {file = "ruff-0.5.7-py3-none-win32.whl", hash = "sha256:33d61fc0e902198a3e55719f4be6b375b28f860b09c281e4bdbf783c0566576a"}, + {file = "ruff-0.5.7-py3-none-win_amd64.whl", hash = "sha256:083bbcbe6fadb93cd86709037acc510f86eed5a314203079df174c40bbbca6b3"}, + {file = "ruff-0.5.7-py3-none-win_arm64.whl", hash = "sha256:2dca26154ff9571995107221d0aeaad0e75a77b5a682d6236cf89a58c70b76f4"}, + {file = "ruff-0.5.7.tar.gz", hash = "sha256:8dfc0a458797f5d9fb622dd0efc52d796f23f0a1493a9527f4e49a550ae9a7e5"}, ] [[package]] @@ -4814,19 +5054,19 @@ win32 = ["pywin32"] [[package]] name = "setuptools" -version = "72.1.0" +version = "73.0.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-72.1.0-py3-none-any.whl", hash = "sha256:5a03e1860cf56bb6ef48ce186b0e557fdba433237481a9a625176c2831be15d1"}, - {file = "setuptools-72.1.0.tar.gz", hash = "sha256:8d243eff56d095e5817f796ede6ae32941278f542e0f941867cc05ae52b162ec"}, + {file = "setuptools-73.0.1-py3-none-any.whl", hash = "sha256:b208925fcb9f7af924ed2dc04708ea89791e24bde0d3020b27df0e116088b34e"}, + {file = "setuptools-73.0.1.tar.gz", hash = "sha256:d59a3e788ab7e012ab2c4baed1b376da6366883ee20d7a5fc426816e3d7b1193"}, ] [package.extras] -core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "ordered-set (>=3.1.1)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", 
"pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] [[package]] name = "six" @@ -4877,13 +5117,13 @@ files = [ [[package]] name = "soupsieve" -version = "2.5" +version = "2.6" description = "A modern CSS selector implementation for Beautiful Soup." optional = false python-versions = ">=3.8" files = [ - {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, - {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, + {file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"}, + {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"}, ] [[package]] @@ -5196,13 +5436,13 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0, [[package]] name = "types-python-dateutil" -version = "2.9.0.20240316" +version = "2.9.0.20240821" description = "Typing stubs for python-dateutil" optional = false python-versions = ">=3.8" files = [ - {file = "types-python-dateutil-2.9.0.20240316.tar.gz", hash = "sha256:5d2f2e240b86905e40944dd787db6da9263f0deabef1076ddaed797351ec0202"}, - {file = "types_python_dateutil-2.9.0.20240316-py3-none-any.whl", hash = "sha256:6b8cb66d960771ce5ff974e9dd45e38facb81718cc1e208b10b1baccbfdbee3b"}, + {file = "types-python-dateutil-2.9.0.20240821.tar.gz", hash = "sha256:9649d1dcb6fef1046fb18bebe9ea2aa0028b160918518c34589a46045f6ebd98"}, + {file = "types_python_dateutil-2.9.0.20240821-py3-none-any.whl", hash = "sha256:f5889fcb4e63ed4aaa379b44f93c32593d50b9a94c9a60a0c854d8cc3511cd57"}, ] [[package]] @@ -5367,13 +5607,13 @@ files = [ [[package]] name = "webcolors" -version = "24.6.0" +version = "24.8.0" description = "A library for working with the color formats defined by HTML and CSS." 
optional = false python-versions = ">=3.8" files = [ - {file = "webcolors-24.6.0-py3-none-any.whl", hash = "sha256:8cf5bc7e28defd1d48b9e83d5fc30741328305a8195c29a8e668fa45586568a1"}, - {file = "webcolors-24.6.0.tar.gz", hash = "sha256:1d160d1de46b3e81e58d0a280d0c78b467dc80f47294b91b1ad8029d2cedb55b"}, + {file = "webcolors-24.8.0-py3-none-any.whl", hash = "sha256:fc4c3b59358ada164552084a8ebee637c221e4059267d0f8325b3b560f6c7f0a"}, + {file = "webcolors-24.8.0.tar.gz", hash = "sha256:08b07af286a01bcd30d583a7acadf629583d1f79bfef27dd2c2c5c263817277d"}, ] [package.extras] @@ -5602,13 +5842,13 @@ multidict = ">=4.0" [[package]] name = "zipp" -version = "3.19.2" +version = "3.20.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.19.2-py3-none-any.whl", hash = "sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c"}, - {file = "zipp-3.19.2.tar.gz", hash = "sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19"}, + {file = "zipp-3.20.0-py3-none-any.whl", hash = "sha256:58da6168be89f0be59beb194da1250516fdaa062ccebd30127ac65d30045e10d"}, + {file = "zipp-3.20.0.tar.gz", hash = "sha256:0145e43d89664cfe1a2e533adc75adafed82fe2da404b4bbb6b026c0157bdb31"}, ] [package.extras] @@ -5618,4 +5858,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "b2c972b8acfe25a9da7a68b3c2694a431908d66738b8082d48bfe124a78585b9" +content-hash = "f06dbe201e3dfea982b0b052a3d6811e1be7acde113f3276f61390bd80684447" From 16c4759f436434baf3c41239b78919ca1759a60d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Thu, 22 Aug 2024 10:17:53 -0700 Subject: [PATCH 22/87] Correct file path in text files --- graphrag/index/input/text.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/graphrag/index/input/text.py b/graphrag/index/input/text.py index fe5ae25f8e..c9609d5673 100644 --- a/graphrag/index/input/text.py +++ b/graphrag/index/input/text.py @@ -34,15 +34,15 @@ async def load_file( ) -> dict[str, Any]: if group is None: group = {} - text = await storage.get(path, encoding="utf-8") + text = await storage.get(path.replace('\\','/'), encoding="utf-8") new_item = {**group, "text": text} new_item["id"] = gen_md5_hash(new_item, new_item.keys()) new_item["title"] = str(Path(path).name) return new_item base_dir = config.base_dir - if config.type == "file": + # if config.type == "file": # base dir is already being added to root dir in case of type file. - base_dir = None + # base_dir = None files = list( storage.find( re.compile(config.file_pattern), @@ -67,7 +67,7 @@ async def load_file( for file, group in files: try: - files_loaded.append(await load_file(file, group)) + files_loaded.append(await load_file(base_dir + file, group)) except Exception: # noqa: BLE001 (catching Exception is fine here) log.warning("Warning! Error loading file %s. 
Skipping...", file) From 2c9edd7b9488bcf26c34515b475966e7ab059792 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Thu, 22 Aug 2024 10:58:46 -0700 Subject: [PATCH 23/87] Resolve final merge conflict --- graphrag/query/cli.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 5ba1c179ed..8e770a107d 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -39,11 +39,8 @@ ) from common.graph_db_client import GraphDBClient -<<<<<<< HEAD -======= reporter = PrintProgressReporter("") ->>>>>>> 16295e5da8a3e66178c9196a6770b9bf08a8add7 reporter = PrintProgressReporter("") From b9fbb6ad9d8de8268b8157ce89b7040bfba863c0 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Thu, 22 Aug 2024 11:23:42 -0700 Subject: [PATCH 24/87] merge final_entities and final_nodes lancedb style --- graphrag/query/cli.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 2128aff077..c8b5ec9840 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -290,7 +290,7 @@ def run_content_store_local_search( # else [] # ) - reports=kt_read_indexer_reports( description_embedding_store, community_level) + reports_result=kt_read_indexer_reports( description_embedding_store, community_level) #TODO KQLify this. I know at least the read_indedxer_reports needs to be done in Kusto. We are joining the community reports & final nodes. # search_engine = get_local_search_engine( @@ -315,9 +315,11 @@ def run_content_store_local_search( # Create entities table similar to read_indexer_entities, but creating that table in Kusto, not in memory. def create_entities_table(description_embedding_store: BaseVectorStore, community_level: int): - description_embedding_store.execute_query(".drop table entities ifexists") #make sure a stale schema doesn't exist - description_embedding_store.execute_query(".set entities <| (create_final_entities | \ - project id,title=name,text=description,vector=description_embeddings)") + description_embedding_store.execute_query(".set-or-replace entities <| ( \ + create_final_nodes | where level <= 2 | project name=['title'] ,rank=degree,community | \ + summarize community=max(community) by name,rank | join kind=inner \ + create_final_entities on name | project id,title=name,text=description,vector=description_embeddings)") + ''' description_embedding_store.execute_query(f".set entities <| create_final_nodes \ | where level <= {community_level} \ From 621ea1131fa37a048fcaf8c57195aa400b42aed5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Thu, 22 Aug 2024 16:24:32 -0700 Subject: [PATCH 25/87] Add config parameters for graphdb --- common/graph_db_client.py | 12 ++++------ graphrag/config/create_graphrag_config.py | 16 ++++++++++++- graphrag/config/models/__init__.py | 2 ++ graphrag/config/models/graph_rag_config.py | 8 ++++++- graphrag/index/config/pipeline.py | 6 +++++ graphrag/index/create_pipeline_config.py | 1 + graphrag/index/emit/factories.py | 9 ++++---- graphrag/index/emit/graph_db_emitter.py | 8 +++++-- graphrag/index/run.py | 4 ++++ graphrag/query/cli.py | 27 ++++++++++++++-------- 10 files changed, 68 insertions(+), 25 deletions(-) diff --git a/common/graph_db_client.py b/common/graph_db_client.py index 31f21c813f..db1d467df2 100644 --- a/common/graph_db_client.py +++ b/common/graph_db_client.py @@ -1,6 +1,7 @@ import os import pandas as pd +from 
 import numpy as np

 import ast
@@ -12,15 +13,12 @@ import json

 class GraphDBClient:
-    def __init__(self):
-        ACCOUNT_NAME = os.getenv("ACCOUNT_NAME")
-        ACCOUNT_KEY = os.getenv("ACCOUNT_KEY")
-        GRAPHDB_USERNAME = os.getenv("GRAPHDB_USERNAME")
+    def __init__(self, graph_db_params: GraphDBConfig|None):
         self._client=client.Client(
-            url=f"wss://{ACCOUNT_NAME}.gremlin.cosmos.azure.com:443/",
+            url=f"wss://{graph_db_params.account_name}.gremlin.cosmos.azure.com:443/",
             traversal_source="g",
-            username=GRAPHDB_USERNAME,
-            password=f"{ACCOUNT_KEY}",
+            username=graph_db_params.username,
+            password=f"{graph_db_params.account_key}",
             message_serializer=serializer.GraphSONSerializersV2d0(),
         )

diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py
index 0903991548..8477916736 100644
--- a/graphrag/config/create_graphrag_config.py
+++ b/graphrag/config/create_graphrag_config.py
@@ -54,6 +54,7 @@
     SummarizeDescriptionsConfig,
     TextEmbeddingConfig,
     UmapConfig,
+    GraphDBConfig,
 )
 from .read_dotenv import read_dotenv

@@ -550,6 +551,17 @@ def hydrate_parallelization_params(
             files=reader.list("files") or [],
         )

+    with (
+        reader.use(values.get("graphdb")),
+        reader.envvar_prefix(Section.graphdb),
+    ):
+        graphdb_model = GraphDBConfig(
+            account_name=reader.str("account_name") or None,
+            account_key=reader.str("account_key") or None,
+            username=reader.str("username") or None,
+            enabled=reader.bool("enabled") or False,
+        )
+
     encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
     skip_workflows = reader.list("skip_workflows") or []

@@ -576,7 +588,8 @@ def hydrate_parallelization_params(
         skip_workflows=skip_workflows,
         local_search=local_search_model,
         global_search=global_search_model,
-        query_context=query_context_model
+        query_context=query_context_model,
+        graphdb=graphdb_model,
     )


@@ -645,6 +658,7 @@ class Section(str, Enum):
     local_search = "LOCAL_SEARCH"
     global_search = "GLOBAL_SEARCH"
     query_context = "QUERY_CONTEXT"
+    graphdb = "GRAPHDB"


 def _is_azure(llm_type: LLMType | None) -> bool:
diff --git a/graphrag/config/models/__init__.py b/graphrag/config/models/__init__.py
index f2c5185c66..f1d206ef85 100644
--- a/graphrag/config/models/__init__.py
+++ b/graphrag/config/models/__init__.py
@@ -24,6 +24,7 @@
 from .summarize_descriptions_config import SummarizeDescriptionsConfig
 from .text_embedding_config import TextEmbeddingConfig
 from .umap_config import UmapConfig
+from .graphdb_config import GraphDBConfig

 __all__ = [
     "CacheConfig",
@@ -47,4 +48,5 @@
     "SummarizeDescriptionsConfig",
     "TextEmbeddingConfig",
     "UmapConfig",
+    "GraphDBConfig",
 ]
diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py
index ab3cbd9fdd..e7249a9016 100644
--- a/graphrag/config/models/graph_rag_config.py
+++ b/graphrag/config/models/graph_rag_config.py
@@ -4,6 +4,7 @@
 """Parameterization settings for the default configuration."""

 from devtools import pformat
+from graphrag.config.models.graphdb_config import GraphDBConfig
 from pydantic import Field

 import graphrag.config.defaults as defs
@@ -149,4 +150,9 @@ def __str__(self):
     query_context: QueryContextConfig = Field(
         description="The query context to use.", default=[]
     )
-    """The query context to use."""
\ No newline at end of file
+    """The query context to use."""
+
+    graphdb: GraphDBConfig = Field(
+        description="The parameters to use graphdb.", default_factory=GraphDBConfig
+    )
+    """The parameters to use graphdb."""
\ No newline at end of file
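Note: the new graphdb model above is hydrated from the graphdb section of settings.yaml (with environment-variable overrides under the GRAPHDB section prefix added to the Section enum in this patch). A minimal sketch of the expected settings block follows; the account name and Gremlin username are illustrative placeholders, not values from this series, and only the four fields defined by GraphDBConfig are read:

    graphdb:
      account_name: 'my-cosmos-account'        # client connects to wss://my-cosmos-account.gremlin.cosmos.azure.com:443/
      account_key: '<cosmos-primary-key>'
      username: '/dbs/graphrag/colls/entities'  # Cosmos DB Gremlin convention: /dbs/<database>/colls/<graph>
      enabled: true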
diff --git a/graphrag/index/config/pipeline.py b/graphrag/index/config/pipeline.py index 30d866e349..fc3328b422 100644 --- a/graphrag/index/config/pipeline.py +++ b/graphrag/index/config/pipeline.py @@ -6,6 +6,7 @@ from __future__ import annotations from devtools import pformat +from graphrag.config.models.graphdb_config import GraphDBConfig from pydantic import BaseModel from pydantic import Field as pydantic_Field @@ -62,3 +63,8 @@ def __str__(self): description="The workflows for the pipeline.", default_factory=list ) """The workflows for the pipeline.""" + + graphdb_params: GraphDBConfig|None = pydantic_Field( + description="Parameters for Graphdb collection", default=None + ) + """Parameters for Graphdb collection""" diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index 439f64cd75..9710816444 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -134,6 +134,7 @@ def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineC *_community_workflows(settings, covariates_enabled, embedded_fields), *(_covariate_workflows(settings) if covariates_enabled else []), ], + graphdb_params=settings.graphdb ) # Remove any workflows that were specified to be skipped diff --git a/graphrag/index/emit/factories.py b/graphrag/index/emit/factories.py index cd9e203917..b9417787af 100644 --- a/graphrag/index/emit/factories.py +++ b/graphrag/index/emit/factories.py @@ -3,6 +3,7 @@ """Table Emitter Factories.""" +from graphrag.config.models.graphdb_config import GraphDBConfig from graphrag.index.storage import PipelineStorage from graphrag.index.typing import ErrorHandlerFn @@ -13,9 +14,8 @@ from .table_emitter import TableEmitter from .types import TableEmitterType - def create_table_emitter( - emitter_type: TableEmitterType, storage: PipelineStorage, on_error: ErrorHandlerFn + emitter_type: TableEmitterType, storage: PipelineStorage, on_error: ErrorHandlerFn, graphdb_params: GraphDBConfig|None = None ) -> TableEmitter: """Create a table emitter based on the specified type.""" match emitter_type: @@ -26,7 +26,7 @@ def create_table_emitter( case TableEmitterType.CSV: return CSVTableEmitter(storage) case TableEmitterType.Graphdb: - return GraphDBEmitter() + return GraphDBEmitter(graphdb_params) case _: msg = f"Unsupported table emitter type: {emitter_type}" raise ValueError(msg) @@ -36,9 +36,10 @@ def create_table_emitters( emitter_types: list[TableEmitterType], storage: PipelineStorage, on_error: ErrorHandlerFn, + graphdb_params: GraphDBConfig|None = None, ) -> list[TableEmitter]: """Create a list of table emitters based on the specified types.""" return [ - create_table_emitter(emitter_type, storage, on_error) + create_table_emitter(emitter_type, storage, on_error, graphdb_params) for emitter_type in emitter_types ] diff --git a/graphrag/index/emit/graph_db_emitter.py b/graphrag/index/emit/graph_db_emitter.py index aa8ed7cfb7..5365835e77 100644 --- a/graphrag/index/emit/graph_db_emitter.py +++ b/graphrag/index/emit/graph_db_emitter.py @@ -8,15 +8,19 @@ import pandas as pd +from graphrag.config.models.graphdb_config import GraphDBConfig from gremlin_python.driver import client, serializer from .table_emitter import TableEmitter from common.graph_db_client import GraphDBClient +from graphrag.index.storage import PipelineStorage + class GraphDBEmitter(TableEmitter): - def __init__(self): - self.graph_db_client = GraphDBClient() + + def __init__(self, graph_db_params: GraphDBConfig|None): + 
self.graph_db_client = GraphDBClient(graph_db_params)
         self.allowed_workflows = ['create_final_entities','create_final_relationships']

     async def emit(self, name: str, data: pd.DataFrame) -> None:
diff --git a/graphrag/index/run.py b/graphrag/index/run.py
index 94a519de87..7ea6faea2b 100644
--- a/graphrag/index/run.py
+++ b/graphrag/index/run.py
@@ -24,6 +24,7 @@
     WorkflowCallbacksManager,
     WorkflowRunResult,
 )
+from graphrag.config.models.graphdb_config import GraphDBConfig

 from .cache import InMemoryCache, PipelineCache, load_cache
 from .config import (
@@ -164,6 +165,7 @@ def _create_postprocess_steps(
         progress_reporter=progress_reporter,
         emit=emit,
         is_resume_run=is_resume_run,
+        graphdb_params=config.graphdb_params,
     ):
         yield table

@@ -181,6 +183,7 @@ async def run_pipeline(
     emit: list[TableEmitterType] | None = None,
     memory_profile: bool = False,
     is_resume_run: bool = False,
+    graphdb_params: GraphDBConfig|None = None,
     **_kwargs: dict,
 ) -> AsyncIterable[PipelineRunResult]:
     """Run the pipeline.
@@ -216,6 +219,7 @@
         lambda e, s, d: cast(WorkflowCallbacks, callbacks).on_error(
             "Error emitting table", e, s, d
         ),
+        graphdb_params,
     )
     loaded_workflows = load_workflows(
         workflows,
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index 97781c90b9..400f2dbaf1 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -104,13 +104,19 @@ def run_global_search(
     data_dir, root_dir, config = _configure_paths_and_settings(
         data_dir, root_dir, config_dir
     )
-    graph_db_client = GraphDBClient()
+    if config.graphdb.enabled:
+        graph_db_client = GraphDBClient(config.graphdb)
     data_path = Path(data_dir)

     final_nodes: pd.DataFrame = pd.read_parquet(
         data_path / "create_final_nodes.parquet"
     )
-    final_entities = graph_db_client.query_vertices()
+    if config.graphdb.enabled:
+        final_entities = graph_db_client.query_vertices()
+    else:
+        final_entities: pd.DataFrame = pd.read_parquet(
+            data_path / "create_final_entities.parquet"
+        )
     final_community_reports: pd.DataFrame = pd.read_parquet(
         data_path / "create_final_community_reports.parquet"
     )
@@ -145,7 +151,7 @@ def run_local_search(
     data_dir, root_dir, config = _configure_paths_and_settings(
         data_dir, root_dir, config_dir
     )
-
+
     data_paths = []
     data_paths = get_files_by_contextid(config, context_id)
     #data_paths = [Path("E:\\graphrag\\ragtest6\\output\\AtoG\\artifacts")]
     final_nodes = pd.DataFrame()
     final_community_reports = pd.DataFrame()
     final_text_units = pd.DataFrame()
     final_relationships = pd.DataFrame()
     final_entities = pd.DataFrame()
     final_covariates = pd.DataFrame()
-    graph_db_client = GraphDBClient()
+    if config.graphdb.enabled:
+        graph_db_client = GraphDBClient(config.graphdb)

     for data_path in data_paths:
         #check from the config for the output storage type and then read the data from the storage.
@@ -168,12 +175,12 @@ def run_local_search( final_text_units = pd.concat([final_text_units, read_paraquet_file(config, data_path + "/create_final_text_units.parquet", config.storage.type)]) - #final_relationships = pd.concat([final_relationships, read_paraquet_file(config, data_path + "/create_final_relationships.parquet", config.storage.type)]) - final_relationships = pd.concat([final_relationships, graph_db_client.query_edges()]) - - - #final_entities = pd.concat([final_entities, read_paraquet_file(config, data_path + "/create_final_entities.parquet", config.storage.type)]) - final_entities = pd.concat([final_entities, graph_db_client.query_vertices()]) + if config.graphdb.enabled: + final_relationships = pd.concat([final_relationships, graph_db_client.query_edges()]) + final_entities = pd.concat([final_entities, graph_db_client.query_vertices()]) + else: + final_relationships = pd.concat([final_relationships, read_paraquet_file(config, data_path + "/create_final_relationships.parquet", config.storage.type)]) + final_entities = pd.concat([final_entities, read_paraquet_file(config, data_path + "/create_final_entities.parquet", config.storage.type)]) data_path_object = Path(data_path) final_covariates_path = data_path_object / "create_final_covariates.parquet" From cae28b3080551991137d84c045dbef5c92064f2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Fri, 23 Aug 2024 11:49:32 -0700 Subject: [PATCH 26/87] Add default values in graphrag/index/init_content.py --- graphrag/index/init_content.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/graphrag/index/init_content.py b/graphrag/index/init_content.py index af4e11b606..5214abcc85 100644 --- a/graphrag/index/init_content.py +++ b/graphrag/index/init_content.py @@ -160,7 +160,13 @@ # concurrency: {defs.GLOBAL_SEARCH_CONCURRENCY} query_context: - # Files: [] # list of files in context to run query + # Files: [] # list of files in context to run query + +graphdb: + account_name: '' + account_key: '' + username: '' + enabled: false """ INIT_DOTENV = """ From d5a8e7cf46c9911d196021fd9c08602949c0d530 Mon Sep 17 00:00:00 2001 From: Amritpal Singh Date: Fri, 23 Aug 2024 14:49:37 -0700 Subject: [PATCH 27/87] Moving Pipeline Storage to Common + Export query artifacts --- graphrag/{index => common}/config/storage.py | 0 .../{index => common}/progress/__init__.py | 0 graphrag/{index => common}/progress/rich.py | 0 graphrag/{index => common}/progress/types.py | 0 .../{index => common}/storage/__init__.py | 0 .../storage/blob_pipeline_storage.py | 2 +- .../storage/file_pipeline_storage.py | 3 +- .../{index => common}/storage/load_storage.py | 2 +- .../storage/memory_pipeline_storage.py | 0 graphrag/{index => common}/storage/typing.py | 2 +- graphrag/index/__init__.py | 2 +- graphrag/index/cache/json_pipeline_cache.py | 2 +- graphrag/index/cache/load_cache.py | 2 +- graphrag/index/cli.py | 4 +- graphrag/index/config/__init__.py | 2 +- graphrag/index/config/pipeline.py | 2 +- graphrag/index/context.py | 2 +- graphrag/index/create_pipeline_config.py | 2 +- graphrag/index/emit/csv_table_emitter.py | 2 +- graphrag/index/emit/factories.py | 2 +- graphrag/index/emit/json_table_emitter.py | 2 +- graphrag/index/emit/parquet_table_emitter.py | 2 +- graphrag/index/input/csv.py | 4 +- graphrag/index/input/load_input.py | 4 +- graphrag/index/input/text.py | 14 ++-- .../reporting/progress_workflow_callbacks.py | 2 +- graphrag/index/run.py | 4 +- graphrag/index/verbs/snapshot.py | 2 +- 
graphrag/index/verbs/snapshot_rows.py | 2 +- graphrag/query/cli.py | 65 +++++++++---------- .../local_search/mixed_context.py | 52 ++++++++------- .../structured_search/local_search/search.py | 39 +++++++++++ tests/smoke/test_fixtures.py | 2 +- .../cache/test_file_pipeline_cache.py | 2 +- .../storage/test_blob_pipeline_storage.py | 2 +- .../storage/test_file_pipeline_storage.py | 2 +- 36 files changed, 135 insertions(+), 96 deletions(-) rename graphrag/{index => common}/config/storage.py (100%) rename graphrag/{index => common}/progress/__init__.py (100%) rename graphrag/{index => common}/progress/rich.py (100%) rename graphrag/{index => common}/progress/types.py (100%) rename graphrag/{index => common}/storage/__init__.py (100%) rename graphrag/{index => common}/storage/blob_pipeline_storage.py (99%) rename graphrag/{index => common}/storage/file_pipeline_storage.py (97%) rename graphrag/{index => common}/storage/load_storage.py (96%) rename graphrag/{index => common}/storage/memory_pipeline_storage.py (100%) rename graphrag/{index => common}/storage/typing.py (97%) diff --git a/graphrag/index/config/storage.py b/graphrag/common/config/storage.py similarity index 100% rename from graphrag/index/config/storage.py rename to graphrag/common/config/storage.py diff --git a/graphrag/index/progress/__init__.py b/graphrag/common/progress/__init__.py similarity index 100% rename from graphrag/index/progress/__init__.py rename to graphrag/common/progress/__init__.py diff --git a/graphrag/index/progress/rich.py b/graphrag/common/progress/rich.py similarity index 100% rename from graphrag/index/progress/rich.py rename to graphrag/common/progress/rich.py diff --git a/graphrag/index/progress/types.py b/graphrag/common/progress/types.py similarity index 100% rename from graphrag/index/progress/types.py rename to graphrag/common/progress/types.py diff --git a/graphrag/index/storage/__init__.py b/graphrag/common/storage/__init__.py similarity index 100% rename from graphrag/index/storage/__init__.py rename to graphrag/common/storage/__init__.py diff --git a/graphrag/index/storage/blob_pipeline_storage.py b/graphrag/common/storage/blob_pipeline_storage.py similarity index 99% rename from graphrag/index/storage/blob_pipeline_storage.py rename to graphrag/common/storage/blob_pipeline_storage.py index 8501734f6d..6acc761e5c 100644 --- a/graphrag/index/storage/blob_pipeline_storage.py +++ b/graphrag/common/storage/blob_pipeline_storage.py @@ -13,7 +13,7 @@ from azure.storage.blob import BlobServiceClient from datashaper import Progress -from graphrag.index.progress import ProgressReporter +from graphrag.common.progress import ProgressReporter from .typing import PipelineStorage diff --git a/graphrag/index/storage/file_pipeline_storage.py b/graphrag/common/storage/file_pipeline_storage.py similarity index 97% rename from graphrag/index/storage/file_pipeline_storage.py rename to graphrag/common/storage/file_pipeline_storage.py index ee61bab3dd..212783e41f 100644 --- a/graphrag/index/storage/file_pipeline_storage.py +++ b/graphrag/common/storage/file_pipeline_storage.py @@ -16,7 +16,7 @@ from aiofiles.ospath import exists from datashaper import Progress -from graphrag.index.progress import ProgressReporter +from graphrag.common.progress import ProgressReporter from .typing import PipelineStorage @@ -113,6 +113,7 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None: is_bytes = isinstance(value, bytes) write_type = "wb" if is_bytes else "w" encoding = None if is_bytes else encoding or 
self._encoding
+        os.makedirs(os.path.dirname(join_path(self._root_dir, key)), mode=0o777, exist_ok=True)
         async with aiofiles.open(
             join_path(self._root_dir, key),
             cast(Any, write_type),
diff --git a/graphrag/index/storage/load_storage.py b/graphrag/common/storage/load_storage.py
similarity index 96%
rename from graphrag/index/storage/load_storage.py
rename to graphrag/common/storage/load_storage.py
index 33d61ee97f..24a6675a04 100644
--- a/graphrag/index/storage/load_storage.py
+++ b/graphrag/common/storage/load_storage.py
@@ -8,7 +8,7 @@
 from typing import cast

 from graphrag.config import StorageType
-from graphrag.index.config.storage import (
+from graphrag.common.config.storage import (
     PipelineBlobStorageConfig,
     PipelineFileStorageConfig,
     PipelineStorageConfig,
diff --git a/graphrag/index/storage/memory_pipeline_storage.py b/graphrag/common/storage/memory_pipeline_storage.py
similarity index 100%
rename from graphrag/index/storage/memory_pipeline_storage.py
rename to graphrag/common/storage/memory_pipeline_storage.py
diff --git a/graphrag/index/storage/typing.py b/graphrag/common/storage/typing.py
similarity index 97%
rename from graphrag/index/storage/typing.py
rename to graphrag/common/storage/typing.py
index 595baf4efd..c5f0de3265 100644
--- a/graphrag/index/storage/typing.py
+++ b/graphrag/common/storage/typing.py
@@ -8,7 +8,7 @@
 from collections.abc import Iterator
 from typing import Any

-from graphrag.index.progress import ProgressReporter
+from graphrag.common.progress import ProgressReporter


 class PipelineStorage(metaclass=ABCMeta):
diff --git a/graphrag/index/__init__.py b/graphrag/index/__init__.py
index c97c290a94..38ab263620 100644
--- a/graphrag/index/__init__.py
+++ b/graphrag/index/__init__.py
@@ -38,7 +38,7 @@
 )
 from .load_pipeline_config import load_pipeline_config
 from .run import run_pipeline, run_pipeline_with_config
-from .storage import PipelineStorage
+from graphrag.common.storage import PipelineStorage

 __all__ = [
     "NoWorkflowsDefinedError",
diff --git a/graphrag/index/cache/json_pipeline_cache.py b/graphrag/index/cache/json_pipeline_cache.py
index b88a38990c..30e73fedc6 100644
--- a/graphrag/index/cache/json_pipeline_cache.py
+++ b/graphrag/index/cache/json_pipeline_cache.py
@@ -6,7 +6,7 @@
 import json
 from typing import Any

-from graphrag.index.storage import PipelineStorage
+from graphrag.common.storage.typing import PipelineStorage

 from .pipeline_cache import PipelineCache
diff --git a/graphrag/index/cache/load_cache.py b/graphrag/index/cache/load_cache.py
index 4e0e6324fb..1a97b2e4de 100644
--- a/graphrag/index/cache/load_cache.py
+++ b/graphrag/index/cache/load_cache.py
@@ -12,7 +12,7 @@
     PipelineBlobCacheConfig,
     PipelineFileCacheConfig,
 )
-from graphrag.index.storage import BlobPipelineStorage, FilePipelineStorage
+from graphrag.common.storage import BlobPipelineStorage, FilePipelineStorage

 if TYPE_CHECKING:
     from graphrag.index.config import (
diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py
index cc557a890a..0a370f42a7 100644
--- a/graphrag/index/cli.py
+++ b/graphrag/index/cli.py
@@ -20,12 +20,12 @@
 from graphrag.common.utils.common_utils import is_valid_guid
 from graphrag.index import PipelineConfig, create_pipeline_config
 from graphrag.index.cache import NoopPipelineCache
-from graphrag.index.progress import (
+from graphrag.common.progress import (
     NullProgressReporter,
     PrintProgressReporter,
     ProgressReporter,
 )
-from graphrag.index.progress.rich import RichProgressReporter
+from graphrag.common.progress.rich import RichProgressReporter
 from
graphrag.index.run import run_pipeline_with_config from .emit import TableEmitterType diff --git a/graphrag/index/config/__init__.py b/graphrag/index/config/__init__.py index 3c40762a84..ad30859b81 100644 --- a/graphrag/index/config/__init__.py +++ b/graphrag/index/config/__init__.py @@ -25,7 +25,7 @@ PipelineReportingConfig, PipelineReportingConfigTypes, ) -from .storage import ( +from ...common.config.storage import ( PipelineBlobStorageConfig, PipelineFileStorageConfig, PipelineMemoryStorageConfig, diff --git a/graphrag/index/config/pipeline.py b/graphrag/index/config/pipeline.py index fc3328b422..e8bbbdbf4c 100644 --- a/graphrag/index/config/pipeline.py +++ b/graphrag/index/config/pipeline.py @@ -13,7 +13,7 @@ from .cache import PipelineCacheConfigTypes from .input import PipelineInputConfigTypes from .reporting import PipelineReportingConfigTypes -from .storage import PipelineStorageConfigTypes +from ...common.config.storage import PipelineStorageConfigTypes from .workflow import PipelineWorkflowReference diff --git a/graphrag/index/context.py b/graphrag/index/context.py index e74799bd35..cdec0f6292 100644 --- a/graphrag/index/context.py +++ b/graphrag/index/context.py @@ -8,7 +8,7 @@ from dataclasses import field from .cache import PipelineCache -from .storage.typing import PipelineStorage +from graphrag.common.storage.typing import PipelineStorage @dc_dataclass diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index 9710816444..a0ae4f9b3c 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -40,7 +40,7 @@ PipelineFileReportingConfig, PipelineReportingConfigTypes, ) -from graphrag.index.config.storage import ( +from graphrag.common.config.storage import ( PipelineBlobStorageConfig, PipelineFileStorageConfig, PipelineMemoryStorageConfig, diff --git a/graphrag/index/emit/csv_table_emitter.py b/graphrag/index/emit/csv_table_emitter.py index c0305c254b..0c208d1264 100644 --- a/graphrag/index/emit/csv_table_emitter.py +++ b/graphrag/index/emit/csv_table_emitter.py @@ -7,7 +7,7 @@ import pandas as pd -from graphrag.index.storage import PipelineStorage +from graphrag.common.storage import PipelineStorage from .table_emitter import TableEmitter diff --git a/graphrag/index/emit/factories.py b/graphrag/index/emit/factories.py index b9417787af..edabba4506 100644 --- a/graphrag/index/emit/factories.py +++ b/graphrag/index/emit/factories.py @@ -4,7 +4,7 @@ """Table Emitter Factories.""" from graphrag.config.models.graphdb_config import GraphDBConfig -from graphrag.index.storage import PipelineStorage +from graphrag.common.storage import PipelineStorage from graphrag.index.typing import ErrorHandlerFn from .csv_table_emitter import CSVTableEmitter diff --git a/graphrag/index/emit/json_table_emitter.py b/graphrag/index/emit/json_table_emitter.py index 0b18c717a6..39f936b781 100644 --- a/graphrag/index/emit/json_table_emitter.py +++ b/graphrag/index/emit/json_table_emitter.py @@ -7,7 +7,7 @@ import pandas as pd -from graphrag.index.storage import PipelineStorage +from graphrag.common.storage import PipelineStorage from .table_emitter import TableEmitter diff --git a/graphrag/index/emit/parquet_table_emitter.py b/graphrag/index/emit/parquet_table_emitter.py index 753915a79a..aa6dd38f96 100644 --- a/graphrag/index/emit/parquet_table_emitter.py +++ b/graphrag/index/emit/parquet_table_emitter.py @@ -9,7 +9,7 @@ import pandas as pd from pyarrow.lib import ArrowInvalid, ArrowTypeError -from 
graphrag.index.storage import PipelineStorage +from graphrag.common.storage import PipelineStorage from graphrag.index.typing import ErrorHandlerFn from .table_emitter import TableEmitter diff --git a/graphrag/index/input/csv.py b/graphrag/index/input/csv.py index 2e4864a98c..04f43ddda0 100644 --- a/graphrag/index/input/csv.py +++ b/graphrag/index/input/csv.py @@ -11,8 +11,8 @@ import pandas as pd from graphrag.index.config import PipelineCSVInputConfig, PipelineInputConfig -from graphrag.index.progress import ProgressReporter -from graphrag.index.storage import PipelineStorage +from graphrag.common.progress import ProgressReporter +from graphrag.common.storage import PipelineStorage from graphrag.index.utils import gen_md5_hash log = logging.getLogger(__name__) diff --git a/graphrag/index/input/load_input.py b/graphrag/index/input/load_input.py index 6d62334210..f4aaa5cf5a 100644 --- a/graphrag/index/input/load_input.py +++ b/graphrag/index/input/load_input.py @@ -12,8 +12,8 @@ from graphrag.config import InputConfig, InputType from graphrag.index.config import PipelineInputConfig -from graphrag.index.progress import NullProgressReporter, ProgressReporter -from graphrag.index.storage import ( +from graphrag.common.progress import NullProgressReporter, ProgressReporter +from graphrag.common.storage import ( BlobPipelineStorage, FilePipelineStorage, ) diff --git a/graphrag/index/input/text.py b/graphrag/index/input/text.py index c9609d5673..51a5da9d82 100644 --- a/graphrag/index/input/text.py +++ b/graphrag/index/input/text.py @@ -11,8 +11,8 @@ import pandas as pd from graphrag.index.config import PipelineInputConfig -from graphrag.index.progress import ProgressReporter -from graphrag.index.storage import PipelineStorage +from graphrag.common.progress import ProgressReporter +from graphrag.common.storage import PipelineStorage from graphrag.index.utils import gen_md5_hash DEFAULT_FILE_PATTERN = re.compile( @@ -34,15 +34,15 @@ async def load_file( ) -> dict[str, Any]: if group is None: group = {} - text = await storage.get(path.replace('\\','/'), encoding="utf-8") + text = await storage.get(path, encoding="utf-8") new_item = {**group, "text": text} new_item["id"] = gen_md5_hash(new_item, new_item.keys()) new_item["title"] = str(Path(path).name) return new_item base_dir = config.base_dir - # if config.type == "file": - # base dir is already being added to root dir in case of type file. - # base_dir = None + if config.type == "file": + #base dir is already being added to root dir in case of type file. + base_dir = None files = list( storage.find( re.compile(config.file_pattern), @@ -67,7 +67,7 @@ async def load_file( for file, group in files: try: - files_loaded.append(await load_file(base_dir + file, group)) + files_loaded.append(await load_file(file, group)) except Exception: # noqa: BLE001 (catching Exception is fine here) log.warning("Warning! Error loading file %s. 
Skipping...", file) diff --git a/graphrag/index/reporting/progress_workflow_callbacks.py b/graphrag/index/reporting/progress_workflow_callbacks.py index 68f10d7530..68e407e223 100644 --- a/graphrag/index/reporting/progress_workflow_callbacks.py +++ b/graphrag/index/reporting/progress_workflow_callbacks.py @@ -7,7 +7,7 @@ from datashaper import ExecutionNode, NoopWorkflowCallbacks, Progress, TableContainer -from graphrag.index.progress import ProgressReporter +from graphrag.common.progress import ProgressReporter class ProgressWorkflowCallbacks(NoopWorkflowCallbacks): diff --git a/graphrag/index/run.py b/graphrag/index/run.py index 7ea6faea2b..7ca8bb9264 100644 --- a/graphrag/index/run.py +++ b/graphrag/index/run.py @@ -47,13 +47,13 @@ from .emit import TableEmitterType, create_table_emitters from .input import load_input from .load_pipeline_config import load_pipeline_config -from .progress import NullProgressReporter, ProgressReporter +from graphrag.common.progress import NullProgressReporter, ProgressReporter from .reporting import ( ConsoleWorkflowCallbacks, ProgressWorkflowCallbacks, load_pipeline_reporter, ) -from .storage import MemoryPipelineStorage, PipelineStorage, load_storage +from graphrag.common.storage import MemoryPipelineStorage, PipelineStorage, load_storage from .typing import PipelineRunResult # Register all verbs diff --git a/graphrag/index/verbs/snapshot.py b/graphrag/index/verbs/snapshot.py index a90fc2837b..b781478532 100644 --- a/graphrag/index/verbs/snapshot.py +++ b/graphrag/index/verbs/snapshot.py @@ -5,7 +5,7 @@ from datashaper import TableContainer, VerbInput, verb -from graphrag.index.storage import PipelineStorage +from graphrag.common.storage import PipelineStorage @verb(name="snapshot") diff --git a/graphrag/index/verbs/snapshot_rows.py b/graphrag/index/verbs/snapshot_rows.py index 99aae70a04..0b0ca1c3b6 100644 --- a/graphrag/index/verbs/snapshot_rows.py +++ b/graphrag/index/verbs/snapshot_rows.py @@ -9,7 +9,7 @@ from datashaper import TableContainer, VerbInput, verb -from graphrag.index.storage import PipelineStorage +from graphrag.common.storage import PipelineStorage @dataclass diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 400f2dbaf1..20daa2feb8 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -3,10 +3,12 @@ """Command line interface for the query module.""" +import asyncio import os from pathlib import Path from typing import cast from io import BytesIO +from graphrag.common.storage import PipelineStorage, BlobPipelineStorage, FilePipelineStorage from graphrag.common.utils.context_utils import get_files_by_contextid from graphrag.config.enums import StorageType from azure.core.exceptions import ResourceNotFoundError @@ -17,7 +19,7 @@ create_graphrag_config, GraphRagConfig, ) -from graphrag.index.progress import PrintProgressReporter +from graphrag.common.progress import PrintProgressReporter from graphrag.model.entity import Entity from graphrag.query.input.loaders.dfs import ( store_entity_semantic_embeddings, @@ -26,8 +28,6 @@ from graphrag.vector_stores.base import BaseVectorStore from graphrag.vector_stores.lancedb import LanceDBVectorStore from graphrag.vector_stores.kusto import KustoVectorStore -from graphrag.common.blob_storage_client import BlobStorageClient - from .factories import get_global_search_engine, get_local_search_engine from .indexer_adapters import ( read_indexer_covariates, @@ -152,11 +152,21 @@ def run_local_search( data_dir, root_dir, config_dir ) + # for the POC purpose input artifacts blob, 
output artifacts blob, and the input query blob storage are the same.
+    if(config.storage.type == StorageType.memory):
+        raise ValueError("Memory storage is not supported")
+    if(config.storage.type == StorageType.blob):
+        if(config.storage.container_name is not None):
+            input_storage_client: PipelineStorage = BlobPipelineStorage(config.storage.connection_string, config.storage.container_name)
+            output_storage_client: PipelineStorage = BlobPipelineStorage(config.storage.connection_string, config.storage.container_name)
+        else:
+            raise ValueError("Storage type is Blob but the container name is invalid")
+    if(config.storage.type == StorageType.file):
+        input_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir)
+        output_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir)
+
     data_paths = []
     data_paths = get_files_by_contextid(config, context_id)
-    #data_paths = [Path("E:\\graphrag\\ragtest6\\output\\AtoG\\artifacts")]
-    #data_paths = [Path("E:\\graphrag\\auditlogstest\\output\\securityPlatformPPE\\artifacts"),Path("E:\\graphrag\\auditlogstest\\output\\UnifiedFeedbackPPE\\artifacts")]
-    #data_paths.append(Path(data_dir))
     final_nodes = pd.DataFrame()
     final_community_reports = pd.DataFrame()
     final_text_units = pd.DataFrame()
@@ -169,25 +179,20 @@ def run_local_search(
     for data_path in data_paths:
         #check from the config for the output storage type and then read the data from the storage.

         #GraphDB: we may need to make change below to read nodes data from Graph DB
-        final_nodes = pd.concat([final_nodes, read_paraquet_file(config, data_path + "/create_final_nodes.parquet", config.storage.type)])
+        final_nodes = pd.concat([final_nodes, read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")])

-        final_community_reports = pd.concat([final_community_reports,read_paraquet_file(config, data_path + "/create_final_community_reports.parquet", config.storage.type)])
+        final_community_reports = pd.concat([final_community_reports, read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")])

-        final_text_units = pd.concat([final_text_units, read_paraquet_file(config, data_path + "/create_final_text_units.parquet", config.storage.type)])
+        final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")])

         if config.graphdb.enabled:
             final_relationships = pd.concat([final_relationships, graph_db_client.query_edges()])
             final_entities = pd.concat([final_entities, graph_db_client.query_vertices()])
         else:
-            final_relationships = pd.concat([final_relationships, read_paraquet_file(config, data_path + "/create_final_relationships.parquet", config.storage.type)])
-            final_entities = pd.concat([final_entities, read_paraquet_file(config, data_path + "/create_final_entities.parquet", config.storage.type)])
-
-        data_path_object = Path(data_path)
-        final_covariates_path = data_path_object / "create_final_covariates.parquet"
+            final_relationships = pd.concat([final_relationships, read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")])
+            final_entities = pd.concat([final_entities, graph_db_client.query_vertices()])

-        final_covariates = pd.concat([final_covariates, (
-            read_paraquet_file(config, final_covariates_path, config.storage.type) if final_covariates_path.exists() else None
-        )])
+        final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")])

     vector_store_args = (
config.embeddings.vector_store if config.embeddings.vector_store else {}
@@ -221,7 +226,9 @@ def run_local_search(
         response_type=response_type,
     )

-    result = search_engine.search(query=query)
+    result = search_engine.optimized_search(query=query) # change to search() if the final text response is needed
+    for key in result.context_data.keys():
+        asyncio.run(output_storage_client.set("query/output/" + key + ".parquet", result.context_data[key].to_parquet())) # the editor flags this call as an error, but it is valid
     reporter.success(f"Local Search Response: {result.response}")
     return result.response
@@ -236,23 +243,13 @@ def blob_exists(container_client, blob_name):
         return False

-def read_paraquet_file(config:GraphRagConfig, path: str, storageType: StorageType):
+def read_paraquet_file(storage: PipelineStorage, path: str):
     # TODO: introduce a dedicated enum for the parquet storage type
-    if storageType == StorageType.blob:
-        container_name = config.input.container_name or ""
-        blobStorageClient = BlobStorageClient(connection_string=config.input.connection_string, container_name=container_name, encoding="utf-8")
-        container_client = blobStorageClient.get_container_client()
-        if blob_exists(container_client, path):
-            blob_data = container_client.download_blob(blob=path)
-            bytes_io = BytesIO(blob_data.readall())
-            return pd.read_parquet(bytes_io, engine="pyarrow")
-        else:
-            return pd.DataFrame() # return empty data frame as covariates file doesn't exist
-    else:
-        file_path = Path(path)
-        if not file_path.exists():
-            return pd.DataFrame()
-        return pd.read_parquet(path)
+    file_data = asyncio.run(storage.get(path, True))
+    if file_data is None:
+        return pd.DataFrame()
+    return pd.read_parquet(BytesIO(file_data), engine="pyarrow")
+

 # TODO I split this out for now to preserve how the original local search worked.
 # I don't think this will necessarily be permanently separate.
 # It was just easier without having to keep everything generic and work the same way as local search worked.
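Note: the reworked read_paraquet_file helper above runs against any PipelineStorage implementation instead of talking to blob storage directly, and returns an empty DataFrame when the artifact is missing. A minimal usage sketch under the file-storage path; the ./ragtest root and the artifact path are illustrative, not values from this series:

    from graphrag.common.storage import FilePipelineStorage
    from graphrag.query.cli import read_paraquet_file

    # FilePipelineStorage resolves keys relative to its root directory
    storage = FilePipelineStorage("./ragtest")
    # returns an empty DataFrame instead of raising when the file does not exist
    nodes = read_paraquet_file(storage, "output/artifacts/create_final_nodes.parquet")
    print(len(nodes))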
diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index 117101e913..cc6ba2d150 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -111,6 +111,7 @@ def build_context( min_community_rank: int = 0, community_context_name: str = "Reports", column_delimiter: str = "|", + is_optimized_flow: bool = False, **kwargs: dict[str, Any], ) -> tuple[str | list[str], dict[str, pd.DataFrame]]: """ @@ -172,20 +173,21 @@ def build_context( ) # build community context - community_tokens = max(int(max_tokens * community_prop), 0) - community_context, community_context_data = self._build_community_context( - selected_entities=selected_entities, - max_tokens=community_tokens, - use_community_summary=use_community_summary, - column_delimiter=column_delimiter, - include_community_rank=include_community_rank, - min_community_rank=min_community_rank, - return_candidate_context=return_candidate_context, - context_name=community_context_name, - ) - if community_context.strip() != "": - final_context.append(community_context) - final_context_data = {**final_context_data, **community_context_data} + if not is_optimized_flow: + community_tokens = max(int(max_tokens * community_prop), 0) + community_context, community_context_data = self._build_community_context( + selected_entities=selected_entities, + max_tokens=community_tokens, + use_community_summary=use_community_summary, + column_delimiter=column_delimiter, + include_community_rank=include_community_rank, + min_community_rank=min_community_rank, + return_candidate_context=return_candidate_context, + context_name=community_context_name, + ) + if community_context.strip() != "": + final_context.append(community_context) + final_context_data = {**final_context_data, **community_context_data} # build local (i.e. 
entity-relationship-covariate) context local_prop = 1 - community_prop - text_unit_prop @@ -204,17 +206,17 @@ def build_context( if local_context.strip() != "": final_context.append(str(local_context)) final_context_data = {**final_context_data, **local_context_data} - - # build text unit context - text_unit_tokens = max(int(max_tokens * text_unit_prop), 0) - text_unit_context, text_unit_context_data = self._build_text_unit_context( - selected_entities=selected_entities, - max_tokens=text_unit_tokens, - return_candidate_context=return_candidate_context, - ) - if text_unit_context.strip() != "": - final_context.append(text_unit_context) - final_context_data = {**final_context_data, **text_unit_context_data} + if not is_optimized_flow: + # build text unit context + text_unit_tokens = max(int(max_tokens * text_unit_prop), 0) + text_unit_context, text_unit_context_data = self._build_text_unit_context( + selected_entities=selected_entities, + max_tokens=text_unit_tokens, + return_candidate_context=return_candidate_context, + ) + if text_unit_context.strip() != "": + final_context.append(text_unit_context) + final_context_data = {**final_context_data, **text_unit_context_data} return ("\n\n".join(final_context), final_context_data) diff --git a/graphrag/query/structured_search/local_search/search.py b/graphrag/query/structured_search/local_search/search.py index 80dd667004..40d94b906e 100644 --- a/graphrag/query/structured_search/local_search/search.py +++ b/graphrag/query/structured_search/local_search/search.py @@ -157,3 +157,42 @@ def search( llm_calls=1, prompt_tokens=num_tokens(search_prompt, self.token_encoder), ) + + + def optimized_search( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> SearchResult: + """Build local search context data.""" + start_time = time.time() + search_prompt = "" + context_text, context_records = self.context_builder.build_context( + query=query, + conversation_history=conversation_history, + **kwargs, + **self.context_builder_params, + isOptimizedFlow=True, + ) + log.info("GENERATE ANSWER: %d. 
QUERY: %s", start_time, query) + try: + return SearchResult( + response="", + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + + except Exception: + log.exception("Exception in _map_response_single_batch") + return SearchResult( + response="", + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py index b5118c6a50..bfe8a20b68 100644 --- a/tests/smoke/test_fixtures.py +++ b/tests/smoke/test_fixtures.py @@ -15,7 +15,7 @@ import pandas as pd import pytest -from graphrag.index.storage.blob_pipeline_storage import BlobPipelineStorage +from graphrag.common.storage.blob_pipeline_storage import BlobPipelineStorage log = logging.getLogger(__name__) diff --git a/tests/unit/indexing/cache/test_file_pipeline_cache.py b/tests/unit/indexing/cache/test_file_pipeline_cache.py index ada3239602..a5d78452d8 100644 --- a/tests/unit/indexing/cache/test_file_pipeline_cache.py +++ b/tests/unit/indexing/cache/test_file_pipeline_cache.py @@ -7,7 +7,7 @@ from graphrag.index.cache import ( JsonPipelineCache, ) -from graphrag.index.storage.file_pipeline_storage import ( +from graphrag.common.storage.file_pipeline_storage import ( FilePipelineStorage, ) diff --git a/tests/unit/indexing/storage/test_blob_pipeline_storage.py b/tests/unit/indexing/storage/test_blob_pipeline_storage.py index d2ea868347..62e631c939 100644 --- a/tests/unit/indexing/storage/test_blob_pipeline_storage.py +++ b/tests/unit/indexing/storage/test_blob_pipeline_storage.py @@ -4,7 +4,7 @@ import re -from graphrag.index.storage.blob_pipeline_storage import BlobPipelineStorage +from graphrag.common.storage.blob_pipeline_storage import BlobPipelineStorage # cspell:disable-next-line well-known-key WELL_KNOWN_BLOB_STORAGE_KEY = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;" diff --git a/tests/unit/indexing/storage/test_file_pipeline_storage.py b/tests/unit/indexing/storage/test_file_pipeline_storage.py index 643c2a9f69..36315a17c2 100644 --- a/tests/unit/indexing/storage/test_file_pipeline_storage.py +++ b/tests/unit/indexing/storage/test_file_pipeline_storage.py @@ -6,7 +6,7 @@ import re from pathlib import Path -from graphrag.index.storage.file_pipeline_storage import FilePipelineStorage +from graphrag.common.storage.file_pipeline_storage import FilePipelineStorage __dirname__ = os.path.dirname(__file__) From 8e2063789d353785f9883dc9b948e18c5b1acad7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Fri, 23 Aug 2024 15:13:25 -0700 Subject: [PATCH 28/87] Add missing config file --- graphrag/config/models/graphdb_config.py | 27 ++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 graphrag/config/models/graphdb_config.py diff --git a/graphrag/config/models/graphdb_config.py b/graphrag/config/models/graphdb_config.py new file mode 100644 index 0000000000..9be8c8751c --- /dev/null +++ b/graphrag/config/models/graphdb_config.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class GraphDBConfig(BaseModel):
+    account_name: str|None = Field(
+        description="Graphdb account name",
+        default=None
+    )
+    account_key: str|None = Field(
+        description="Graphdb account key",
+        default=None
+    )
+    username: str|None = Field(
+        description="Graphdb username",
+        default=None
+    )
+    enabled: bool = Field(
+        description="Flag to enable querying into graphdb",
+        default=False
+    )
\ No newline at end of file

From c46d71cfaf0b22d2117e4f1ad7d40ba592235e9f Mon Sep 17 00:00:00 2001
From: Amritpal Singh
Date: Fri, 23 Aug 2024 15:41:21 -0700
Subject: [PATCH 29/87] Small fixes to align with GraphDBClient

---
 graphrag/index/emit/graph_db_emitter.py             | 14 +++-----------
 graphrag/query/cli.py                               |  2 +-
 .../query/structured_search/local_search/search.py  |  2 +-
 3 files changed, 5 insertions(+), 13 deletions(-)

diff --git a/graphrag/index/emit/graph_db_emitter.py b/graphrag/index/emit/graph_db_emitter.py
index 5365835e77..4bb4244fa4 100644
--- a/graphrag/index/emit/graph_db_emitter.py
+++ b/graphrag/index/emit/graph_db_emitter.py
@@ -3,21 +3,13 @@

 """GraphDBEmitter module."""

-import logging
-import traceback
-
 import pandas as pd
-
-from graphrag.config.models.graphdb_config import GraphDBConfig
-from gremlin_python.driver import client, serializer
-
-from .table_emitter import TableEmitter
-
 from common.graph_db_client import GraphDBClient
-
-from graphrag.index.storage import PipelineStorage
+from .table_emitter import TableEmitter
+from graphrag.config.models.graphdb_config import GraphDBConfig

 class GraphDBEmitter(TableEmitter):
+    """Graph DB Emitter."""

     def __init__(self, graph_db_params: GraphDBConfig|None):
         self.graph_db_client = GraphDBClient(graph_db_params)
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index 20daa2feb8..a0c648e43c 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -190,7 +190,7 @@ def run_local_search(
             final_entities = pd.concat([final_entities, graph_db_client.query_vertices()])
         else:
             final_relationships = pd.concat([final_relationships, read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")])
-            final_entities = pd.concat([final_entities, graph_db_client.query_vertices()])
+            final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")])

         final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")])
diff --git a/graphrag/query/structured_search/local_search/search.py b/graphrag/query/structured_search/local_search/search.py
index 40d94b906e..a6fe4117e1 100644
--- a/graphrag/query/structured_search/local_search/search.py
+++ b/graphrag/query/structured_search/local_search/search.py
@@ -173,7 +173,7 @@ def optimized_search(
             conversation_history=conversation_history,
             **kwargs,
             **self.context_builder_params,
-            isOptimizedFlow=True,
+            is_optimized_flow=True,
         )
         log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
QUERY: %s", start_time, query) try: From b9dff03a4e265d1c745a3ac3940d3e2023f517ae Mon Sep 17 00:00:00 2001 From: Amritpal Singh Date: Fri, 23 Aug 2024 18:32:15 -0700 Subject: [PATCH 30/87] saving changes --- graphrag/query/__main__.py | 8 ++ graphrag/query/cli.py | 7 +- .../context_builder/community_context.py | 6 +- .../query/context_builder/local_context.py | 10 +- graphrag/query/factories.py | 2 + .../local_search/mixed_context.py | 112 ++++++++++-------- .../structured_search/local_search/search.py | 3 +- 7 files changed, 88 insertions(+), 60 deletions(-) diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py index c6b34a78f4..919c39cb28 100644 --- a/graphrag/query/__main__.py +++ b/graphrag/query/__main__.py @@ -78,6 +78,13 @@ def __str__(self): default="00000000-0000-0000-0000-000000000000", ) + parser.add_argument( + "--optimized_search", + help="Runs optimized search and export artifacts", + type=bool, + default=False, + ) + parser.add_argument( "query", nargs=1, @@ -99,6 +106,7 @@ def __str__(self): args.response_type, args.context_id, args.query[0], + optimized_search=args.optimized_search ) case SearchType.GLOBAL: run_global_search( diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index a0c648e43c..627d5e44bc 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -146,6 +146,7 @@ def run_local_search( response_type: str, context_id: str, query: str, + optimized_search: bool = False, ): """Run a local search with the given query.""" data_dir, root_dir, config = _configure_paths_and_settings( @@ -224,9 +225,13 @@ def run_local_search( covariates={"claims": covariates}, description_embedding_store=description_embedding_store, response_type=response_type, + is_optimized_search=optimized_search, ) - result = search_engine.optimized_search(query=query) # changed it to search if we want to get final text response. + if optimized_search: + result = search_engine.optimized_search(query=query) + else: + result = search_engine.search(query=query) for key in result.context_data.keys(): asyncio.run(output_storage_client.set("query/output/"+ key +".paraquet", result.context_data[key].to_parquet())) #it shows as error in editor but not an error. reporter.success(f"Local Search Response: {result.response}") diff --git a/graphrag/query/context_builder/community_context.py b/graphrag/query/context_builder/community_context.py index 398f8ac422..ad345a2704 100644 --- a/graphrag/query/context_builder/community_context.py +++ b/graphrag/query/context_builder/community_context.py @@ -33,6 +33,7 @@ def build_community_context( single_batch: bool = True, context_name: str = "Reports", random_state: int = 86, + is_optimized_search: bool = False, ) -> tuple[str | list[str], dict[str, pd.DataFrame]]: """ Prepare community report data table as context data for system prompt. 
@@ -57,7 +58,8 @@ def _get_header(attributes: list[str]) -> list[str]:
         return header

     def _report_context_text(
-        report: CommunityReport, attributes: list[str]
+        report: CommunityReport, attributes: list[str],
+        is_optimized_search: bool = False,
     ) -> tuple[str, list[str]]:
         context: list[str] = [
             report.short_id if report.short_id else "",
@@ -143,7 +145,7 @@ def _cut_batch() -> None:
         _init_batch()

     for report in selected_reports:
-        new_context_text, new_context = _report_context_text(report, attributes)
+        new_context_text, new_context = _report_context_text(report, attributes, is_optimized_search)
         new_tokens = num_tokens(new_context_text, token_encoder)

         if batch_tokens + new_tokens > max_tokens:
diff --git a/graphrag/query/context_builder/local_context.py b/graphrag/query/context_builder/local_context.py
index c2f7527e33..89c2559562 100644
--- a/graphrag/query/context_builder/local_context.py
+++ b/graphrag/query/context_builder/local_context.py
@@ -33,6 +33,7 @@ def build_entity_context(
     rank_description: str = "number of relationships",
     column_delimiter: str = "|",
     context_name="Entities",
+    is_optimized_search: bool = False,
 ) -> tuple[str, pd.DataFrame]:
     """Prepare entity data table as context data for system prompt."""
     if len(selected_entities) == 0:
@@ -68,10 +69,11 @@ def build_entity_context(
             else ""
         )
         new_context.append(field_value)
-        new_context_text = column_delimiter.join(new_context) + "\n"
-        new_tokens = num_tokens(new_context_text, token_encoder)
-        if current_tokens + new_tokens > max_tokens:
-            break
+        new_context_text = column_delimiter.join(new_context) + "\n"
+        new_tokens = num_tokens(new_context_text, token_encoder)
+        # the optimized flow skips the token budget check but still tracks text and tokens
+        if not is_optimized_search and current_tokens + new_tokens > max_tokens:
+            break
         current_context_text += new_context_text
         all_context_records.append(new_context)
         current_tokens += new_tokens
diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py
index 8b6d58fb7e..3dfe104230 100644
--- a/graphrag/query/factories.py
+++ b/graphrag/query/factories.py
@@ -108,6 +108,7 @@ def get_local_search_engine(
     covariates: dict[str, list[Covariate]],
     response_type: str,
     description_embedding_store: BaseVectorStore,
+    is_optimized_search: bool = False,
 ) -> LocalSearch:
     """Create a local search engine based on data + configuration."""
     llm = get_llm(config)
@@ -128,6 +129,7 @@ def get_local_search_engine(
             embedding_vectorstore_key=EntityVectorStoreKey.ID,  # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE
             text_embedder=text_embedder,
             token_encoder=token_encoder,
+            is_optimized_search=is_optimized_search,
         ),
         token_encoder=token_encoder,
         llm_params={
diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py
index cc6ba2d150..683dad2f48 100644
--- a/graphrag/query/structured_search/local_search/mixed_context.py
+++ b/graphrag/query/structured_search/local_search/mixed_context.py
@@ -61,6 +61,7 @@ def __init__(
         covariates: dict[str, list[Covariate]] | None = None,
         token_encoder: tiktoken.Encoding | None = None,
         embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
+        is_optimized_search: bool = False,
     ):
         if community_reports is None:
             community_reports = []
@@ -83,6 +84,7 @@ def __init__(
         self.text_embedder = text_embedder
         self.token_encoder = token_encoder
         self.embedding_vectorstore_key = embedding_vectorstore_key
+        self.is_optimized_search = is_optimized_search

     def filter_by_entity_keys(self, entity_keys: list[int] | list[str]):
         """Filter
entity text embeddings by entity keys.""" @@ -111,7 +113,7 @@ def build_context( min_community_rank: int = 0, community_context_name: str = "Reports", column_delimiter: str = "|", - is_optimized_flow: bool = False, + is_optimized_search: bool = False, **kwargs: dict[str, Any], ) -> tuple[str | list[str], dict[str, pd.DataFrame]]: """ @@ -173,21 +175,21 @@ def build_context( ) # build community context - if not is_optimized_flow: - community_tokens = max(int(max_tokens * community_prop), 0) - community_context, community_context_data = self._build_community_context( - selected_entities=selected_entities, - max_tokens=community_tokens, - use_community_summary=use_community_summary, - column_delimiter=column_delimiter, - include_community_rank=include_community_rank, - min_community_rank=min_community_rank, - return_candidate_context=return_candidate_context, - context_name=community_context_name, - ) - if community_context.strip() != "": - final_context.append(community_context) - final_context_data = {**final_context_data, **community_context_data} + community_tokens = max(int(max_tokens * community_prop), 0) + community_context, community_context_data = self._build_community_context( + selected_entities=selected_entities, + max_tokens=community_tokens, + use_community_summary=use_community_summary, + column_delimiter=column_delimiter, + include_community_rank=include_community_rank, + min_community_rank=min_community_rank, + return_candidate_context=return_candidate_context, + context_name=community_context_name, + is_optimized_search=is_optimized_search + ) + if community_context.strip() != "": + final_context.append(community_context) + final_context_data = {**final_context_data, **community_context_data} # build local (i.e. entity-relationship-covariate) context local_prop = 1 - community_prop - text_unit_prop @@ -206,7 +208,7 @@ def build_context( if local_context.strip() != "": final_context.append(str(local_context)) final_context_data = {**final_context_data, **local_context_data} - if not is_optimized_flow: + if not self.is_optimized_search: # build text unit context text_unit_tokens = max(int(max_tokens * text_unit_prop), 0) text_unit_context, text_unit_context_data = self._build_text_unit_context( @@ -230,6 +232,7 @@ def _build_community_context( min_community_rank: int = 0, return_candidate_context: bool = False, context_name: str = "Reports", + is_optimized_search: bool = False, ) -> tuple[str, dict[str, pd.DataFrame]]: """Add community data to the context window until it hits the max_tokens limit.""" if len(selected_entities) == 0 or len(self.community_reports) == 0: @@ -260,46 +263,49 @@ def _build_community_context( ) for community in selected_communities: del community.attributes["matches"] # type: ignore - - context_text, context_data = build_community_context( - community_reports=selected_communities, - token_encoder=self.token_encoder, - use_community_summary=use_community_summary, - column_delimiter=column_delimiter, - shuffle_data=False, - include_community_rank=include_community_rank, - min_community_rank=min_community_rank, - max_tokens=max_tokens, - single_batch=True, - context_name=context_name, - ) - if isinstance(context_text, list) and len(context_text) > 0: - context_text = "\n\n".join(context_text) - - if return_candidate_context: - candidate_context_data = get_candidate_communities( - selected_entities=selected_entities, - community_reports=list(self.community_reports.values()), + context_data = {} + context_data["reports"] = selected_communities + 
context_text = "" + if not is_optimized_search: + context_text, context_data = build_community_context( + community_reports=selected_communities, + token_encoder=self.token_encoder, use_community_summary=use_community_summary, + column_delimiter=column_delimiter, + shuffle_data=False, include_community_rank=include_community_rank, + min_community_rank=min_community_rank, + max_tokens=max_tokens, + single_batch=True, + context_name=context_name, ) - context_key = context_name.lower() - if context_key not in context_data: - context_data[context_key] = candidate_context_data - context_data[context_key]["in_context"] = False - else: - if ( - "id" in candidate_context_data.columns - and "id" in context_data[context_key].columns - ): - candidate_context_data["in_context"] = candidate_context_data[ - "id" - ].isin( # cspell:disable-line - context_data[context_key]["id"] - ) + if isinstance(context_text, list) and len(context_text) > 0: + context_text = "\n\n".join(context_text) + + if return_candidate_context: + candidate_context_data = get_candidate_communities( + selected_entities=selected_entities, + community_reports=list(self.community_reports.values()), + use_community_summary=use_community_summary, + include_community_rank=include_community_rank, + ) + context_key = context_name.lower() + if context_key not in context_data: context_data[context_key] = candidate_context_data + context_data[context_key]["in_context"] = False else: - context_data[context_key]["in_context"] = True + if ( + "id" in candidate_context_data.columns + and "id" in context_data[context_key].columns + ): + candidate_context_data["in_context"] = candidate_context_data[ + "id" + ].isin( # cspell:disable-line + context_data[context_key]["id"] + ) + context_data[context_key] = candidate_context_data + else: + context_data[context_key]["in_context"] = True return (str(context_text), context_data) def _build_text_unit_context( @@ -392,6 +398,7 @@ def _build_local_context( relationship_ranking_attribute: str = "rank", return_candidate_context: bool = False, column_delimiter: str = "|", + is_optimized_search: bool = False ) -> tuple[str, dict[str, pd.DataFrame]]: """Build data context for local search prompt combining entity/relationship/covariate tables.""" # build entity context @@ -403,6 +410,7 @@ def _build_local_context( include_entity_rank=include_entity_rank, rank_description=rank_description, context_name="Entities", + is_optimized_search=is_optimized_search, ) entity_tokens = num_tokens(entity_context, self.token_encoder) diff --git a/graphrag/query/structured_search/local_search/search.py b/graphrag/query/structured_search/local_search/search.py index a6fe4117e1..597b511222 100644 --- a/graphrag/query/structured_search/local_search/search.py +++ b/graphrag/query/structured_search/local_search/search.py @@ -113,6 +113,7 @@ def search( **kwargs, ) -> SearchResult: """Build local search context that fits a single context window and generate answer for the user question.""" + start_time = time.time() search_prompt = "" context_text, context_records = self.context_builder.build_context( @@ -171,9 +172,9 @@ def optimized_search( context_text, context_records = self.context_builder.build_context( query=query, conversation_history=conversation_history, + is_optimized_search = self.optimized_search, **kwargs, **self.context_builder_params, - is_optimized_flow=True, ) log.info("GENERATE ANSWER: %d. 
QUERY: %s", start_time, query)
        try:

From 1088da6d1fd62b91fb4ad6a6ca233ce24c173867 Mon Sep 17 00:00:00 2001
From: Amritpal Singh
Date: Fri, 23 Aug 2024 18:34:59 -0700
Subject: [PATCH 31/87] Fixing minor issue in main branch

---
 graphrag/index/input/text.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/graphrag/index/input/text.py b/graphrag/index/input/text.py
index c9609d5673..fe5ae25f8e 100644
--- a/graphrag/index/input/text.py
+++ b/graphrag/index/input/text.py
@@ -34,15 +34,15 @@ async def load_file(
     ) -> dict[str, Any]:
         if group is None:
             group = {}
-        text = await storage.get(path.replace('\\','/'), encoding="utf-8")
+        text = await storage.get(path, encoding="utf-8")
         new_item = {**group, "text": text}
         new_item["id"] = gen_md5_hash(new_item, new_item.keys())
         new_item["title"] = str(Path(path).name)
         return new_item
 
     base_dir = config.base_dir
-    # if config.type == "file":
+    if config.type == "file":
         # base dir is already being added to root dir in case of type file.
-        # base_dir = None
+        base_dir = None
     files = list(
         storage.find(
             re.compile(config.file_pattern),
@@ -67,7 +67,7 @@
 
     for file, group in files:
         try:
-            files_loaded.append(await load_file(base_dir + file, group))
+            files_loaded.append(await load_file(file, group))
         except Exception:  # noqa: BLE001 (catching Exception is fine here)
             log.warning("Warning! Error loading file %s. Skipping...", file)

From 079c0347bcc6bf7569d2384c2d10d1b8326bb20d Mon Sep 17 00:00:00 2001
From: Liam Dannaher
Date: Mon, 26 Aug 2024 09:46:18 -0700
Subject: [PATCH 32/87] Only using Kusto store for all entities.

---
 graphrag/query/cli.py                      | 94 +++++++++----------
 .../context_builder/entity_extraction.py   | 26 +++++
 graphrag/vector_stores/azure_ai_search.py  | 15 ++-
 graphrag/vector_stores/base.py             | 17 +++-
 graphrag/vector_stores/kusto.py            | 83 ++++++++++++++--
 graphrag/vector_stores/lancedb.py          | 15 ++-
 6 files changed, 189 insertions(+), 61 deletions(-)

diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index 400f2dbaf1..d783a37934 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -167,14 +167,14 @@ def run_local_search(
         graph_db_client = GraphDBClient(config.graphdb)
     for data_path in data_paths:
         #check from the config for the output storage type and then read the data from the storage.
- + #GraphDB: we may need to make change below to read nodes data from Graph DB final_nodes = pd.concat([final_nodes, read_paraquet_file(config, data_path + "/create_final_nodes.parquet", config.storage.type)]) - + final_community_reports = pd.concat([final_community_reports,read_paraquet_file(config, data_path + "/create_final_community_reports.parquet", config.storage.type)]) - + final_text_units = pd.concat([final_text_units, read_paraquet_file(config, data_path + "/create_final_text_units.parquet", config.storage.type)]) - + if config.graphdb.enabled: final_relationships = pd.concat([final_relationships, graph_db_client.query_edges()]) final_entities = pd.concat([final_entities, graph_db_client.query_vertices()]) @@ -202,6 +202,7 @@ def run_local_search( vector_store_type=vector_store_type, config_args=vector_store_args, ) + covariates = ( read_indexer_covariates(final_covariates) if final_covariates.empty is False @@ -224,7 +225,7 @@ def run_local_search( result = search_engine.search(query=query) reporter.success(f"Local Search Response: {result.response}") return result.response - + def blob_exists(container_client, blob_name): blob_client = container_client.get_blob_client(blob_name) try: @@ -253,6 +254,7 @@ def read_paraquet_file(config:GraphRagConfig, path: str, storageType: StorageTyp if not file_path.exists(): return pd.DataFrame() return pd.read_parquet(path) + # TODO I split this out for now to preserve how the original local search worked. # I don't think this will necessarily be permanently separate. # It was just easier without having to keep everything generic and work the same way as local search worked. @@ -269,14 +271,13 @@ def run_content_store_local_search( data_dir, root_dir, config = _configure_paths_and_settings( data_dir, root_dir, config_dir ) - data_path = Path(data_dir) vector_store_args = ( config.embeddings.vector_store if config.embeddings.vector_store else {} ) - + vector_store_type = vector_store_args.get("type", VectorStoreType.Kusto) - + collection_name = vector_store_args.get( "query_collection_name", "entity_description_embeddings" ) @@ -288,61 +289,51 @@ def run_content_store_local_search( description_embedding_store.connect(**vector_store_args) - #TODO add back covariates. I skipped this for now. description_embedding_store.load_parqs(data_dir, ["create_final_nodes", "create_final_community_reports", "create_final_text_units", "create_final_relationships", "create_final_entities"]) - #TODO KQLify this. This merge of nodes & entities needs to happen in Kusto. + gen_parqs = description_embedding_store.read_parqs(data_dir, ["create_final_covariates", "create_final_nodes", "create_final_community_reports", "create_final_text_units", "create_final_relationships", "create_final_entities"]) + dict_parqs = {} + for parq in gen_parqs: + dict_parqs[parq[0]] = parq[1] + final_covariates = dict_parqs.get("create_final_covariates") + final_community_reports = dict_parqs.get("create_final_community_reports") + final_text_units = dict_parqs.get("create_final_text_units") + final_relationships = dict_parqs.get("create_final_relationships") + final_nodes = dict_parqs.get("create_final_nodes") + create_entities_table(description_embedding_store, community_level) - # description_embedding_store = __get_embedding_description_store( - # entities=entities, - # description_embedding_store=description_embedding_store, - # config_args=vector_store_args, - # ) - - #TODO add back covariates w/Kusto. I skipped this for now. 
- # covariates = ( - # read_indexer_covariates(final_covariates) - # if final_covariates is not None - # else [] - # ) + + covariates = ( + read_indexer_covariates(final_covariates) + if final_covariates is not None + else [] + ) reports_result=kt_read_indexer_reports( description_embedding_store, community_level) - #TODO KQLify this. I know at least the read_indedxer_reports needs to be done in Kusto. We are joining the community reports & final nodes. - # search_engine = get_local_search_engine( - # config, - # reports=read_indexer_reports( - # final_community_reports, final_nodes, community_level - # ), - # text_units=read_indexer_text_units(final_text_units), - # entities=entities, - # relationships=read_indexer_relationships(final_relationships), - # covariates={"claims": covariates}, - # description_embedding_store=description_embedding_store, - # response_type=response_type, - # ) - - #TODO This is the biggest TODO. I need to go through the whole mixed_context.py and make sure it's using Kusto data not the parquet data it expects in memory. - # result = search_engine.search(query=query) - # reporter.success(f"Local Search Response: {result.response}") - # return result.response - - return True #Obviously this is a placeholder due to all the TODOs above. + search_engine = get_local_search_engine( + config, + reports=read_indexer_reports( + final_community_reports, final_nodes, community_level + ), + text_units=read_indexer_text_units(final_text_units), + entities=[], + relationships=read_indexer_relationships(final_relationships), + covariates={"claims": covariates}, + description_embedding_store=description_embedding_store, + response_type=response_type, + ) + + result = search_engine.search(query=query) + reporter.success(f"Local Search Response: {result.response}") + return result.response # Create entities table similar to read_indexer_entities, but creating that table in Kusto, not in memory. 
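# (In KQL terms: .set-or-replace materializes the query's result server-side as
# an "entities" table, replacing the pandas join over final nodes and final
# entities that read_indexer_entities performs, so the query path never has to
# load those parquet artifacts into memory.)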
def create_entities_table(description_embedding_store: BaseVectorStore, community_level: int): description_embedding_store.execute_query(".set-or-replace entities <| ( \ create_final_nodes | where level <= 2 | project name=['title'] ,rank=degree,community | \ summarize community=max(community) by name,rank | join kind=inner \ - create_final_entities on name | project id,title=name,text=description,vector=description_embeddings)") - - ''' - description_embedding_store.execute_query(f".set entities <| create_final_nodes \ - | where level <= {community_level} \ - | project community=coalesce(community, 0), name=['title'], rank=degree \ - | summarize community=max(community) by name, rank \ - | join kind=inner create_final_entities on name") - ''' + create_final_entities on name)") def run_content_store_global_search( config_dir: str | None, @@ -358,7 +349,6 @@ def run_content_store_global_search( def _configure_paths_and_settings( - data_dir: str | None, root_dir: str | None, config_dir: str | None, diff --git a/graphrag/query/context_builder/entity_extraction.py b/graphrag/query/context_builder/entity_extraction.py index 82a0699cd8..16958e2b0a 100644 --- a/graphrag/query/context_builder/entity_extraction.py +++ b/graphrag/query/context_builder/entity_extraction.py @@ -31,6 +31,23 @@ def from_string(value: str) -> "EntityVectorStoreKey": msg = f"Invalid EntityVectorStoreKey: {value}" raise ValueError(msg) +def map_query_to_entities_in_place( + query: str, + text_embedding_vectorstore: BaseVectorStore, + text_embedder: BaseTextEmbedding, + k: int = 10, + oversample_scaler: int = 2, +) -> list[Entity]: + """Extract entities that match a given query using semantic similarity of text embeddings of query and entity descriptions.""" + # get entities with highest semantic similarity to query + # oversample to account for excluded entities + search_results = text_embedding_vectorstore.get_extracted_entities( + text=query, + text_embedder=lambda t: text_embedder.embed(t), + k=k * oversample_scaler, + ) + print(search_results) + return search_results def map_query_to_entities( query: str, @@ -44,6 +61,15 @@ def map_query_to_entities( oversample_scaler: int = 2, ) -> list[Entity]: """Extract entities that match a given query using semantic similarity of text embeddings of query and entity descriptions.""" + if all_entities == []: + return map_query_to_entities_in_place( + query, + text_embedding_vectorstore, + text_embedder, + k, + oversample_scaler, + ) + if include_entity_names is None: include_entity_names = [] if exclude_entity_names is None: diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index cd78078d35..452ad30760 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -24,6 +24,7 @@ ) from azure.search.documents.models import VectorizedQuery +from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder from .base import ( @@ -195,6 +196,18 @@ def similarity_search_by_text( def load_parqs(self, data_path, parq_names) -> Any: raise NotImplementedError("Loading Parquet files is not supported for Azure AI Search") - + + def get_extracted_entities( + self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + ) -> list[Entity]: + raise NotImplementedError("Extracting entities is not supported for Azure AI Search") + + def read_parqs(self, data_dir, parq_names) -> Any: + raise NotImplementedError("Reading Parquet files is not supported for Azure AI Search") + + 
def get_related_entities(self, titles:list[str], **kwargs: Any) -> list[Entity]: + """Get related entities from the vector store.""" + raise NotImplementedError("Getting related entities is not supported for Azure AI Search") + def execute_query(self, query: str) -> Any: return super().execute_query(query) \ No newline at end of file diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index be7a400d5c..30621c8637 100644 --- a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field from typing import Any +from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder DEFAULT_VECTOR_SIZE: int = 1536 @@ -87,4 +88,18 @@ def load_parqs(self, data_path: str, parqs: list[str]) -> Any: #TODO This is temporary until I take out the client from the vector store class @abstractmethod def execute_query(self, query: str) -> Any: - """Execute a query in the vector-store.""" \ No newline at end of file + """Execute a query in the vector-store.""" + + @abstractmethod + def get_extracted_entities( + self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + ) -> list[Entity]: + """From a query, build a subtable of entities which is only matching entities.""" + + @abstractmethod + def read_parqs(self, data_dir, parq_names) -> Any: + """Return a dictionary of parquet dataframes of parq_name to data frame.""" + + @abstractmethod + def get_related_entities(self, titles: list[str], **kwargs: Any) -> list[Entity]: + """Get related entities from the vector store.""" \ No newline at end of file diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index f934231c77..421529df77 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -6,6 +6,7 @@ import typing from azure.kusto.data import KustoClient, KustoConnectionStringBuilder from azure.kusto.data.helpers import dataframe_from_result_table +from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder import pandas as pd @@ -31,7 +32,7 @@ class KustoVectorStore(BaseVectorStore): , "create_final_community_reports": "(community: int, full_content: string, level: int, rank: int, title: string, rank_explanation: string, summary: string, findings: string, full_content_json: string, id: string)" , "create_final_text_units": "(id: string, text: string, n_tokens: int, document_ids: string, entity_ids: string, relationship_ids: string)" , "create_final_relationships": "(source: string, target: string, weight: real, description: string, text_unit_ids: string, id: string, human_readable_id: string, source_degree: int, target_degree: int, rank: int)" - , "create_final_entities": "(id: string, name: string, type: string, description: string, human_readable_id: int, graph_embedding: dynamic, text_unit_ids: string, description_embeddings: dynamic )"} + , "create_final_entities": "(id: string, name: string, type: string, description: string, human_readable_id: int, graph_embedding: dynamic, text_unit_ids: string, description_embedding: dynamic)"} def connect(self, **kwargs: Any) -> Any: """ @@ -125,7 +126,7 @@ def similarity_search_by_vector( self, query_embedding: List[float], k: int = 10, **kwargs: Any ) -> List[VectorStoreSearchResult]: """ - Perform a vector-based similarity search. + Perform a vector-based similarity search. A search to find the k nearest neighbors of the given query vector. Args: query_embedding (List[float]): The query embedding vector. 
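An aside on the contract these hunks build out: get_extracted_entities, read_parqs, and get_related_entities (declared abstract in base.py above) let the query path resolve entities store-side instead of in memory. A minimal in-memory sketch of that contract follows; ToyEntity, ToyStore, and the lambda embedder are stand-ins invented for illustration, not code from this series:

from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyEntity:
    # Stand-in mirroring the shape of graphrag.model.entity.Entity.
    id: str
    title: str
    description: str = ""


class ToyStore:
    # In-memory stand-in honoring the new BaseVectorStore methods.

    def __init__(self, entities: list[ToyEntity], edges: list[tuple[str, str]]):
        self._entities = {e.title: e for e in entities}
        self._edges = edges  # (source title, target title) pairs

    def get_extracted_entities(
        self, text: str, text_embedder: Callable[[str], list[float]], k: int = 10
    ) -> list[ToyEntity]:
        # A real store ranks by similarity between text_embedder(text) and the
        # stored vectors; this toy version just returns the first k entities.
        _ = text_embedder(text)
        return list(self._entities.values())[:k]

    def get_related_entities(self, titles: list[str]) -> list[ToyEntity]:
        # Follow relationships in both directions, as the Kusto implementation
        # does with its two KQL queries over create_final_relationships.
        related = []
        for source, target in self._edges:
            if source in titles and target in self._entities:
                related.append(self._entities[target])
            if target in titles and source in self._entities:
                related.append(self._entities[source])
        return related


store = ToyStore(
    [ToyEntity("1", "A"), ToyEntity("2", "B"), ToyEntity("3", "C")],
    [("A", "B"), ("C", "A")],
)
seeds = store.get_extracted_entities("example query", lambda t: [0.0], k=1)
print([e.title for e in store.get_related_entities([e.title for e in seeds])])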
@@ -139,7 +140,6 @@ def similarity_search_by_vector( let query_vector = dynamic({query_embedding}); {self.collection_name} | extend distance = array_length(set_difference(vector, query_vector)) - | where distance <= {k} | top {k} by distance asc """ response = self.client.execute(self.database, query) @@ -160,7 +160,7 @@ def similarity_search_by_vector( def similarity_search_by_text( self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any - ) -> List[VectorStoreSearchResult]: + ) -> list[VectorStoreSearchResult]: """ Perform a similarity search using a given input text. @@ -178,8 +178,69 @@ def similarity_search_by_text( return self.similarity_search_by_vector(query_embedding, k) return [] + def execute_query(self, query: str) -> Any: - self.client.execute(self.database, f"{query}") + return self.client.execute(self.database, f"{query}") + + def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + ) -> list[Entity]: + query_results = self.similarity_search_by_text(text, text_embedder, k) + query_ids = [result.document.id for result in query_results] + if query_ids not in [[], None]: + ids_str = "\", \"".join([str(id) for id in query_ids]) + query = f""" + entities + | where id in ("{ids_str}") + """ + print(query) + response = self.client.execute(self.database, query) + df = dataframe_from_result_table(response.primary_results[0]) + return self.__extract_entities_from_data_frame(df) + return [] + + def __extract_entities_from_data_frame(self, df: pd.DataFrame) -> list[Entity]: + return [ + Entity( + id=row["id"], + title=row["name1"], + type=row["type"], + description=row["description"], + graph_embedding=row["graph_embedding"], + text_unit_ids=row["text_unit_ids"], + description_embedding=row["description_embedding"], + short_id="", + community_ids=[row["community"]], + rank=row["rank"], + attributes={"title":row["name1"]}, + ) + for _, row in df.iterrows() + ] + + def get_related_entities(self, titles: list[str], **kwargs: Any) -> list[Entity]: + """Get related entities based on the given titles.""" + titles_str = "\", \"".join(titles) + + query = f""" + create_final_relationships + | where source in ("{titles_str}") + | project name=target + | join kind=inner create_final_entities on name + """ + response = self.client.execute(self.database, query) + df = dataframe_from_result_table(response.primary_results[0]) + selected_entities = self.__extract_entities_from_data_frame(df) + + query = f""" + create_final_relationships + | where target in ("{titles_str}") + | project name=source + | join kind=inner create_final_entities on name + """ + response = self.client.execute(self.database, query) + df = dataframe_from_result_table(response.primary_results[0]) + selected_entities += self.__extract_entities_from_data_frame(df) + + return selected_entities def load_parqs(self, data_dir, parq_names) -> Any: data_path = Path(data_dir) @@ -202,4 +263,14 @@ def load_parqs(self, data_dir, parq_names) -> Any: command = f".ingest inline into table {parq_name} <| {parq.to_csv(index=False, header=False)}" self.client.execute(self.database, command) else: - print(f"Parquet file {parq_path} not found.") \ No newline at end of file + print(f"Parquet file {parq_path} not found.") + + def read_parqs(self, data_dir, parq_names) -> Any: + """Return a dictionary of parquet dataframes of parq_name to data frame.""" + data_path = Path(data_dir) + for parq_name in parq_names: + parq_path = data_path / f"{parq_name}.parquet" + parq = None + if 
parq_path.exists(): + parq = pd.read_parquet(parq_path) + yield parq_name, parq diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index ddf6effb59..021312500c 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ -4,6 +4,7 @@ """The LanceDB vector storage implementation package.""" import lancedb as lancedb # noqa: I001 (Ruff was breaking on this file imports, even tho they were sorted and passed local tests) +from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder import json @@ -122,6 +123,18 @@ def similarity_search_by_text( def load_parqs(self, data_path, parq_names) -> Any: raise NotImplementedError("Loading Parquet files is not supported for LanceDB") - + + def get_extracted_entities( + self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + ) -> list[Entity]: + raise NotImplementedError("Extracting entities is not supported for LanceDB") + + def read_parqs(self, data_dir, parq_names) -> Any: + raise NotImplementedError("Reading Parquet files is not supported for LanceDB") + + def get_related_entities(self, titles: list[str], **kwargs: Any) -> list[Entity]: + """Get related entities from the vector store.""" + raise NotImplementedError("Getting related entities is not supported for LanceDB") + def execute_query(self, query: str) -> Any: return super().execute_query(query) \ No newline at end of file From 28594c6a4a2ef4e03e6ad087b8a990b59c41b179 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Mon, 26 Aug 2024 15:25:33 -0700 Subject: [PATCH 33/87] Include context into read and write calls for graphdb --- common/graph_db_client.py | 98 +++++++++++++++++++++++++++------------ graphrag/query/cli.py | 4 +- 2 files changed, 70 insertions(+), 32 deletions(-) diff --git a/common/graph_db_client.py b/common/graph_db_client.py index db1d467df2..483c49d2f6 100644 --- a/common/graph_db_client.py +++ b/common/graph_db_client.py @@ -14,10 +14,12 @@ class GraphDBClient: def __init__(self,graph_db_params: GraphDBConfig|None): + self.username_prefix=graph_db_params.username + self.current_context="00000000-0000-0000-0000-000000000000" self._client=client.Client( url=f"wss://{graph_db_params.account_name}.gremlin.cosmos.azure.com:443/", traversal_source="g", - username=graph_db_params.username, + username=self.username_prefix+"-contextid-"+self.current_context, password=f"{graph_db_params.account_key}", message_serializer=serializer.GraphSONSerializersV2d0(), ) @@ -42,7 +44,7 @@ def result_to_df(self,result) -> pd.DataFrame: df = pd.DataFrame(json_data) return df - def query_vertices(self) -> pd.DataFrame: + def query_vertices(self,context_id:str="00000000-0000-0000-0000-000000000000") -> pd.DataFrame: result = self._client.submit( message=( "g.V()" @@ -50,7 +52,7 @@ def query_vertices(self) -> pd.DataFrame: ) return self.result_to_df(result) - def query_edges(self) -> pd.DataFrame: + def query_edges(self,context_id:str="00000000-0000-0000-0000-000000000000") -> pd.DataFrame: result = self._client.submit( message=( "g.E()" @@ -58,40 +60,76 @@ def query_edges(self) -> pd.DataFrame: ) return self.result_to_df(result) - def write_vertices(self,data: pd.DataFrame)->None: - for row in data.itertuples(): - print(row.id) - self._client.submit( + def element_exists(self,element_type:str,element_id:int,conditions:str="")->bool: + result=self._client.submit( message=( - "g.addV('entity')" - ".property('id', prop_id)" - ".property('name', 
prop_name)" - ".property('type', prop_type)" - ".property('description','prop_description')" - ".property('human_readable_id', prop_human_readable_id)" - ".property('category', prop_partition_key)" - ".property(list,'description_embedding',prop_description_embedding)" - ".property(list,'graph_embedding',prop_graph_embedding)" - ".property(list,'text_unit_ids',prop_text_unit_ids)" + element_type+ + ".has('id',prop_id)"+ + conditions+ + ".count()" ), bindings={ - "prop_id": row.id, - "prop_name": row.name, - "prop_type": row.type, - "prop_description": row.description, - "prop_human_readable_id": row.human_readable_id, - "prop_partition_key": "entities", - "prop_description_embedding":json.dumps(row.description_embedding.tolist() if row.description_embedding is not None else []), - "prop_graph_embedding":json.dumps(row.graph_embedding.tolist() if row.graph_embedding is not None else []), - "prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []), - }, - ) + "prop_id":element_id, + } + ) + element_count=0 + for counts in result: + element_count=counts[0] + return element_count>0 + + def switch_graphs(self,context_id:str)->None: + if context_id==self.current_context: + return + self.current_context=context_id + updated_client=client.Client( + url=self._client._url, + traversal_source="g", + username=self.username_prefix+"-contextid-"+self.current_context, + password=self._client._password, + message_serializer=serializer.GraphSONSerializersV2d0(), + ) + self._client.close() + self._client=updated_client + + def write_vertices(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None: + self.switch_graphs(context_id) + for row in data.itertuples(): + if self.element_exists("g.V()",row.id): + continue + else: + self._client.submit( + message=( + "g.addV('entity')" + ".property('id', prop_id)" + ".property('name', prop_name)" + ".property('type', prop_type)" + ".property('description','prop_description')" + ".property('human_readable_id', prop_human_readable_id)" + ".property('category', prop_partition_key)" + ".property(list,'description_embedding',prop_description_embedding)" + ".property(list,'graph_embedding',prop_graph_embedding)" + ".property(list,'text_unit_ids',prop_text_unit_ids)" + ), + bindings={ + "prop_id": row.id, + "prop_name": row.name, + "prop_type": row.type, + "prop_description": row.description, + "prop_human_readable_id": row.human_readable_id, + "prop_partition_key": "entities", + "prop_description_embedding":json.dumps(row.description_embedding.tolist() if row.description_embedding is not None else []), + "prop_graph_embedding":json.dumps(row.graph_embedding.tolist() if row.graph_embedding is not None else []), + "prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []), + }, + ) time.sleep(5) - def write_edges(self,data: pd.DataFrame)->None: + def write_edges(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None: + self.switch_graphs(context_id) for row in data.itertuples(): - print(row.source,row.target) + if self.element_exists("g.E()",row.id): + continue self._client.submit( message=( "g.V().has('name',prop_source_id)" diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index d783a37934..1765c7c18f 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -176,8 +176,8 @@ def run_local_search( final_text_units = pd.concat([final_text_units, read_paraquet_file(config, data_path + "/create_final_text_units.parquet", 
config.storage.type)]) if config.graphdb.enabled: - final_relationships = pd.concat([final_relationships, graph_db_client.query_edges()]) - final_entities = pd.concat([final_entities, graph_db_client.query_vertices()]) + final_relationships = pd.concat([final_relationships, graph_db_client.query_edges(context_id)]) + final_entities = pd.concat([final_entities, graph_db_client.query_vertices(context_id)]) else: final_relationships = pd.concat([final_relationships, read_paraquet_file(config, data_path + "/create_final_relationships.parquet", config.storage.type)]) final_entities = pd.concat([final_entities, read_paraquet_file(config, data_path + "/create_final_entities.parquet", config.storage.type)]) From aefa92419d7f613a0fd0db2b42d0c6e1f6105137 Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Mon, 26 Aug 2024 15:51:52 -0700 Subject: [PATCH 34/87] Using create_final_entities table rather than entity_description_embeddings & issue with to_csv not handling floats properly --- graphrag/index/create_pipeline_config.py | 1 + graphrag/query/__main__.py | 16 +++--- graphrag/query/cli.py | 20 +++++-- .../context_builder/entity_extraction.py | 1 - .../local_search/mixed_context.py | 2 + graphrag/vector_stores/base.py | 2 + graphrag/vector_stores/kusto.py | 53 +++++++++++++++++-- 7 files changed, 76 insertions(+), 19 deletions(-) diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index 9710816444..86c43f5d8c 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -352,6 +352,7 @@ def _graph_workflows( { "title_column": "description", "collection_name": "entity_description_embeddings", + "vector_name": "vector", }, ), "skip_name_embedding": skip_entity_name_embedding, diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py index c6b34a78f4..764e27383d 100644 --- a/graphrag/query/__main__.py +++ b/graphrag/query/__main__.py @@ -6,7 +6,7 @@ import argparse from enum import Enum -from .cli import run_global_search, run_local_search, run_content_store_local_search, run_content_store_global_search +from .cli import run_global_search, run_local_search, run_kusto_local_search, run_kusto_global_search INVALID_METHOD_ERROR = "Invalid method" @@ -16,8 +16,8 @@ class SearchType(Enum): LOCAL = "local" GLOBAL = "global" - CONTENT_STORE_LOCAL = "content_store_local" - CONTENT_STORE_GLOBAL = "content_store_global" + KUSTO_LOCAL = "kusto_local" + KUSTO_GLOBAL = "kusto_global" def __str__(self): """Return the string representation of the enum value.""" @@ -85,7 +85,7 @@ def __str__(self): type=str, ) - + args = parser.parse_args() @@ -110,8 +110,8 @@ def __str__(self): args.context_id, args.query[0], ) - case SearchType.CONTENT_STORE_LOCAL: - run_content_store_local_search( + case SearchType.KUSTO_LOCAL: + run_kusto_local_search( args.config, args.data, args.root, @@ -119,8 +119,8 @@ def __str__(self): args.response_type, args.query[0], ) - case SearchType.CONTENT_STORE_GLOBAL: - run_content_store_global_search( + case SearchType.KUSTO_GLOBAL: + run_kusto_global_search( args.config, args.data, args.root, diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index d783a37934..633836f1a9 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -57,6 +57,11 @@ def __get_embedding_description_store( "query_collection_name", "entity_description_embeddings" ) config_args.update({"collection_name": collection_name}) + vector_name = config_args.get( + "vector_search_column", "vector" + ) + 
config_args.update({"vector_name": vector_name}) + description_embedding_store = VectorStoreFactory.get_vector_store( vector_store_type=vector_store_type, kwargs=config_args ) @@ -259,7 +264,7 @@ def read_paraquet_file(config:GraphRagConfig, path: str, storageType: StorageTyp # I don't think this will necessarily be permanently separate. # It was just easier without having to keep everything generic and work the same way as local search worked. # One last optimization: Once all the merges are done we can go back to the parquet loads and optimize those for only the fields we need and merge them right away into one big table (I think). -def run_content_store_local_search( +def run_kusto_local_search( config_dir: str | None, data_dir: str | None, root_dir: str | None, @@ -267,11 +272,12 @@ def run_content_store_local_search( response_type: str, query: str, ): - """Run a local search with the given query.""" + """Run a local search in Kusto with the given query.""" data_dir, root_dir, config = _configure_paths_and_settings( data_dir, root_dir, config_dir ) + vector_store_args = ( config.embeddings.vector_store if config.embeddings.vector_store else {} ) @@ -279,9 +285,13 @@ def run_content_store_local_search( vector_store_type = vector_store_args.get("type", VectorStoreType.Kusto) collection_name = vector_store_args.get( - "query_collection_name", "entity_description_embeddings" + "query_collection_name", "entities" ) vector_store_args.update({"collection_name": collection_name}) + vector_name = vector_store_args.get( + "vector_search_column", "description_embedding" + ) + vector_store_args.update({"vector_name": vector_name}) description_embedding_store = VectorStoreFactory.get_vector_store( vector_store_type=vector_store_type, kwargs=vector_store_args @@ -335,7 +345,7 @@ def create_entities_table(description_embedding_store: BaseVectorStore, communit summarize community=max(community) by name,rank | join kind=inner \ create_final_entities on name)") -def run_content_store_global_search( +def run_kusto_global_search( config_dir: str | None, data_dir: str | None, root_dir: str | None, @@ -343,7 +353,7 @@ def run_content_store_global_search( response_type: str, query: str, ): - """Run a content store global search with the given query.""" + """Run a global search in Kusto with the given query.""" raise NotImplementedError("This function is not implemented yet.") diff --git a/graphrag/query/context_builder/entity_extraction.py b/graphrag/query/context_builder/entity_extraction.py index 16958e2b0a..fe8253b612 100644 --- a/graphrag/query/context_builder/entity_extraction.py +++ b/graphrag/query/context_builder/entity_extraction.py @@ -46,7 +46,6 @@ def map_query_to_entities_in_place( text_embedder=lambda t: text_embedder.embed(t), k=k * oversample_scaler, ) - print(search_results) return search_results def map_query_to_entities( diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index 117101e913..5fe252e9f8 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -148,6 +148,8 @@ def build_context( oversample_scaler=2, ) + print("Selected entities titles: ", [entity.title for entity in selected_entities]) + # build context final_context = list[str]() final_context_data = dict[str, pd.DataFrame]() diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index 30621c8637..a02fc834da 100644 --- 
a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -44,12 +44,14 @@ class BaseVectorStore(ABC): def __init__( self, collection_name: str, + vector_name: str, db_connection: Any | None = None, document_collection: Any | None = None, query_filter: Any | None = None, **kwargs: Any, ): self.collection_name = collection_name + self.vector_name = vector_name self.db_connection = db_connection self.document_collection = document_collection self.query_filter = query_filter diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 421529df77..6dc1882ac8 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -15,6 +15,16 @@ import json from typing import Any, List, cast +from graphrag.query.input.loaders.utils import ( + to_list, + to_optional_dict, + to_optional_float, + to_optional_int, + to_optional_list, + to_optional_str, + to_str, +) + from .base import ( BaseVectorStore, VectorStoreDocument, @@ -74,7 +84,7 @@ def load_documents( data = [ { "id": document.id, - "text": document.text, + "name": document.text, "vector": document.vector, "attributes": json.dumps(document.attributes), } @@ -139,19 +149,35 @@ def similarity_search_by_vector( query = f""" let query_vector = dynamic({query_embedding}); {self.collection_name} - | extend distance = array_length(set_difference(vector, query_vector)) + | extend distance = array_length(set_difference({self.vector_name}, query_vector)) | top {k} by distance asc """ response = self.client.execute(self.database, query) df = dataframe_from_result_table(response.primary_results[0]) + print("Distances of the search results:", [row["distance"] for _, row in df.iterrows()]) + + # Temporary to support the original entity_description_embedding + if(self.vector_name == "vector"): + return [ + VectorStoreSearchResult( + document=VectorStoreDocument( + id=row["id"], + text=row["text"], + vector=row[self.vector_name], + attributes=row["attributes"], + ), + score=1 - abs(float(row["distance"])), + ) + for _, row in df.iterrows() + ] return [ VectorStoreSearchResult( document=VectorStoreDocument( id=row["id"], - text=row["text"], - vector=row["vector"], - attributes=json.loads(row["attributes"]), + text=row["name"], + vector=row[self.vector_name], + attributes={"title":row["name"]}, ), score=1 - abs(float(row["distance"])), ) @@ -260,6 +286,23 @@ def load_parqs(self, data_dir, parq_names) -> Any: self.client.execute(self.database, command) command = f".create table {parq_name} {self.schema_dict[parq_name]}" self.client.execute(self.database, command) + + # Due to an issue with to_csv not being able to handle float64, I had to manually handle entities. 
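# (A plausible mechanism for the to_csv issue noted above: pandas serializes
# array-valued cells through their repr, and numpy elides long float64 arrays
# with "...", so the inline ingest would receive truncated vectors; converting
# each row to plain Python values first, as below, sidesteps that.)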
+ if parq_name == "create_final_entities": + data = [ + { + "id": to_str(row, "id"), + "name": to_str(row, "name"), + "type": to_optional_str(row, "type"), + "description": to_optional_str(row, "description"), + "human_readable_id": to_optional_str(row, "human_readable_id"), + "graph_embedding": to_optional_list(row, "graph_embedding"), + "text_unit_ids": to_optional_list(row, "text_unit_ids"), + "description_embedding": to_optional_list(row, "description_embedding"), + } + for idx, row in parq.iterrows() + ] + parq = pd.DataFrame(data) command = f".ingest inline into table {parq_name} <| {parq.to_csv(index=False, header=False)}" self.client.execute(self.database, command) else: From 8a1d17c1ca0df7b514ce927000f7ce5d35ee2b2d Mon Sep 17 00:00:00 2001 From: Amritpal Singh Date: Mon, 26 Aug 2024 17:22:02 -0700 Subject: [PATCH 35/87] Optimized Search --- .../index/context_switch/contextSwitcher.py | 3 +-- .../query/context_builder/local_context.py | 18 ++++++++++++------ .../local_search/mixed_context.py | 4 ++++ 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index 7ff1fa023f..f992eb6354 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -1,4 +1,4 @@ -from graphrag.index.progress import ProgressReporter +from graphrag.common.progress import ProgressReporter from graphrag.config import GraphRagConfig class ContextSwitcher: @@ -9,7 +9,6 @@ def __init__(self): pass def activate(self, config: GraphRagConfig | str, contextId: str | None, reporter: ProgressReporter): """Activate the context.""" - #1. read the context id to fileId mapping. #2. read the file from storage using common/blob_storage_client.py #3. GraphDB: use cosmos db client to load data into Cosmos DB. 
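ContextSwitcher.activate is still the three-step to-do list above. A hedged sketch of how those steps could compose, assuming a JSON mapping blob named context_to_files.json (an invented layout), the container client exposed by BlobStorageClient, and the context-aware GraphDBClient.write_vertices added earlier in this series:

import io
import json

import pandas as pd


def activate_sketch(container_client, graph_db_client, context_id: str,
                    mapping_blob: str = "context_to_files.json") -> None:
    """Sketch only: wire the three placeholder steps of activate() together."""
    # 1. read the contextId -> fileId mapping (assumed layout: {"<context>": ["<file>", ...]})
    raw = container_client.get_blob_client(mapping_blob).download_blob().readall()
    mapping: dict[str, list[str]] = json.loads(raw)

    # 2. read each mapped parquet artifact from blob storage
    for file_id in mapping.get(context_id, []):
        payload = container_client.get_blob_client(file_id).download_blob().readall()
        frame = pd.read_parquet(io.BytesIO(payload))

        # 3. load the artifact into Cosmos DB under this context, via the
        #    context-aware writer on GraphDBClient
        graph_db_client.write_vertices(frame, context_id)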
diff --git a/graphrag/query/context_builder/local_context.py b/graphrag/query/context_builder/local_context.py index 89c2559562..bd1790980b 100644 --- a/graphrag/query/context_builder/local_context.py +++ b/graphrag/query/context_builder/local_context.py @@ -69,12 +69,13 @@ def build_entity_context( else "" ) new_context.append(field_value) + new_tokens: int = 0 if not is_optimized_search: new_context_text = column_delimiter.join(new_context) + "\n" new_tokens = num_tokens(new_context_text, token_encoder) if current_tokens + new_tokens > max_tokens: break - current_context_text += new_context_text + current_context_text += new_context_text all_context_records.append(new_context) current_tokens += new_tokens @@ -95,6 +96,7 @@ def build_covariates_context( max_tokens: int = 8000, column_delimiter: str = "|", context_name: str = "Covariates", + is_optimized_search: bool = False ) -> tuple[str, pd.DataFrame]: """Prepare covariate data tables as context data for system prompt.""" # create an empty list of covariates @@ -162,6 +164,7 @@ def build_relationship_context( relationship_ranking_attribute: str = "rank", column_delimiter: str = "|", context_name: str = "Relationships", + is_optimized_search: bool = False ) -> tuple[str, pd.DataFrame]: """Prepare relationship data tables as context data for system prompt.""" selected_relationships = _filter_relationships( @@ -207,11 +210,14 @@ def build_relationship_context( else "" ) new_context.append(field_value) - new_context_text = column_delimiter.join(new_context) + "\n" - new_tokens = num_tokens(new_context_text, token_encoder) - if current_tokens + new_tokens > max_tokens: - break - current_context_text += new_context_text + new_context_text = "" + new_tokens = 0 + if not is_optimized_search: + new_context_text = column_delimiter.join(new_context) + "\n" + new_tokens = num_tokens(new_context_text, token_encoder) + if current_tokens + new_tokens > max_tokens: #General: There could be side impact of generating huge number of relationships + break + current_context_text += new_context_text all_context_records.append(new_context) current_tokens += new_tokens diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index 683dad2f48..177da94919 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -204,6 +204,7 @@ def build_context( relationship_ranking_attribute=relationship_ranking_attribute, return_candidate_context=return_candidate_context, column_delimiter=column_delimiter, + is_optimized_search=is_optimized_search ) if local_context.strip() != "": final_context.append(str(local_context)) @@ -439,6 +440,7 @@ def _build_local_context( include_relationship_weight=include_relationship_weight, relationship_ranking_attribute=relationship_ranking_attribute, context_name="Relationships", + is_optimized_search=is_optimized_search ) current_context.append(relationship_context) current_context_data["relationships"] = relationship_context_data @@ -446,6 +448,7 @@ def _build_local_context( relationship_context, self.token_encoder ) + # build covariate context for covariate in self.covariates: covariate_context, covariate_context_data = build_covariates_context( @@ -455,6 +458,7 @@ def _build_local_context( max_tokens=max_tokens, column_delimiter=column_delimiter, context_name=covariate, + is_optimized_search=is_optimized_search ) total_tokens += num_tokens(covariate_context, 
self.token_encoder)
            current_context.append(covariate_context)

From ca0dcc9deb7cd0faee9b195ee3f9c3c33e8c6a02 Mon Sep 17 00:00:00 2001
From: Liam Dannaher
Date: Tue, 27 Aug 2024 12:10:22 -0700
Subject: [PATCH 37/87] Changing similarity search (query to entity embedding
 search) to use cosine similarity and setting encoding of embedded vectors
 for entities to vector16

---
 graphrag/vector_stores/kusto.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py
index 6dc1882ac8..562783d6c9 100644
--- a/graphrag/vector_stores/kusto.py
+++ b/graphrag/vector_stores/kusto.py
@@ -149,12 +149,12 @@ def
similarity_search_by_vector( query = f""" let query_vector = dynamic({query_embedding}); {self.collection_name} - | extend distance = array_length(set_difference({self.vector_name}, query_vector)) - | top {k} by distance asc + | extend similarity = series_cosine_similarity(query_vector, {self.vector_name}) + | top {k} by similarity desc """ response = self.client.execute(self.database, query) df = dataframe_from_result_table(response.primary_results[0]) - print("Distances of the search results:", [row["distance"] for _, row in df.iterrows()]) + print("Similarities of the search results:", [row["similarity"] for _, row in df.iterrows()]) # Temporary to support the original entity_description_embedding if(self.vector_name == "vector"): @@ -166,7 +166,7 @@ def similarity_search_by_vector( vector=row[self.vector_name], attributes=row["attributes"], ), - score=1 - abs(float(row["distance"])), + score=float(row["similarity"]), ) for _, row in df.iterrows() ] @@ -179,7 +179,7 @@ def similarity_search_by_vector( vector=row[self.vector_name], attributes={"title":row["name"]}, ), - score=1 - abs(float(row["distance"])), + score=float(row["similarity"]), ) for _, row in df.iterrows() ] @@ -289,6 +289,8 @@ def load_parqs(self, data_dir, parq_names) -> Any: # Due to an issue with to_csv not being able to handle float64, I had to manually handle entities. if parq_name == "create_final_entities": + command = f".alter column create_final_entities.graph_embedding policy encoding type = 'Vector16'" + command = f".alter column create_final_entities.description_embedding policy encoding type = 'Vector16'" data = [ { "id": to_str(row, "id"), From a87071df9d44d6bdfef264418791a3f3cf62e29b Mon Sep 17 00:00:00 2001 From: Amritpal Singh Date: Tue, 27 Aug 2024 14:52:16 -0700 Subject: [PATCH 38/87] Working changes --- graphrag/query/cli.py | 12 +++---- .../local_search/mixed_context.py | 32 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 0645635289..4bbba60474 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -181,11 +181,12 @@ def run_local_search( #GraphDB: we may need to make change below to read nodes data from Graph DB final_nodes = pd.concat([final_nodes, read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")]) + final_community_reports = pd.concat([final_community_reports,read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) # KustoDB: Final_entities, Final_Nodes, Final_report should be merged and inserted to kusto + final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) # lance db search need it for embedding mapping. we have embeddings in entities we should use from there. KustoDB already must have sorted it. 
- final_community_reports = pd.concat([final_community_reports,read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) - - final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) - + if not optimized_search: + final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) + if config.graphdb.enabled: final_relationships = pd.concat([final_relationships, graph_db_client.query_edges()]) final_entities = pd.concat([final_entities, graph_db_client.query_vertices()]) @@ -193,7 +194,6 @@ def run_local_search( final_relationships = pd.concat([final_relationships, read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")]) final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")]) - final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) vector_store_args = ( config.embeddings.vector_store if config.embeddings.vector_store else {} @@ -202,7 +202,7 @@ def run_local_search( reporter.info(f"Vector Store Args: {vector_store_args}") vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) # verify kusto vector store here. - entities = read_indexer_entities(final_nodes, final_entities, community_level) # Change it to read file specific indexer files. + entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it. description_embedding_store = __get_embedding_description_store( entities=entities, vector_store_type=vector_store_type, diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index 177da94919..44db158575 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -174,22 +174,22 @@ def build_context( conversation_history_context, self.token_encoder ) - # build community context - community_tokens = max(int(max_tokens * community_prop), 0) - community_context, community_context_data = self._build_community_context( - selected_entities=selected_entities, - max_tokens=community_tokens, - use_community_summary=use_community_summary, - column_delimiter=column_delimiter, - include_community_rank=include_community_rank, - min_community_rank=min_community_rank, - return_candidate_context=return_candidate_context, - context_name=community_context_name, - is_optimized_search=is_optimized_search - ) - if community_context.strip() != "": - final_context.append(community_context) - final_context_data = {**final_context_data, **community_context_data} + if not is_optimized_search: + community_tokens = max(int(max_tokens * community_prop), 0) + community_context, community_context_data = self._build_community_context( + selected_entities=selected_entities, + max_tokens=community_tokens, + use_community_summary=use_community_summary, + column_delimiter=column_delimiter, + include_community_rank=include_community_rank, + min_community_rank=min_community_rank, + return_candidate_context=return_candidate_context, + context_name=community_context_name, + is_optimized_search=is_optimized_search + ) + if community_context.strip() != "": + 
final_context.append(community_context)
+                final_context_data = {**final_context_data, **community_context_data}

         # build local (i.e. entity-relationship-covariate) context
         local_prop = 1 - community_prop - text_unit_prop

From 8bc2c71c30a891a0af877d5aa27398415f839f8b Mon Sep 17 00:00:00 2001
From: Sirus Sh
Date: Wed, 28 Aug 2024 10:33:17 -0700
Subject: [PATCH 39/87] Kusto minor edits

---
 graphrag/vector_stores/kusto.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py
index 562783d6c9..3d135f06b0 100644
--- a/graphrag/vector_stores/kusto.py
+++ b/graphrag/vector_stores/kusto.py
@@ -166,7 +166,7 @@ def similarity_search_by_vector(
                     vector=row[self.vector_name],
                     attributes=row["attributes"],
                 ),
-                score=float(row["similarity"]),
+                score= 1 + float(row["similarity"]),
             )
             for _, row in df.iterrows()
         ]
@@ -179,7 +179,7 @@ def similarity_search_by_vector(
                     vector=row[self.vector_name],
                     attributes={"title":row["name"]},
                 ),
-                score=float(row["similarity"]),
+                score= 1 + float(row["similarity"]), # get a [0,2] range; work with positive numbers
             )
             for _, row in df.iterrows()
         ]
@@ -290,7 +290,9 @@ def load_parqs(self, data_dir, parq_names) -> Any:
             # Due to an issue with to_csv not being able to handle float64, I had to manually handle entities.
             if parq_name == "create_final_entities":
                 command = f".alter column create_final_entities.graph_embedding policy encoding type = 'Vector16'"
+                self.client.execute(self.database, command)
                 command = f".alter column create_final_entities.description_embedding policy encoding type = 'Vector16'"
+                self.client.execute(self.database, command)
                 data = [
                     {
                         "id": to_str(row, "id"),

From 1927e99e0453678ffa8efed9e76ce7111691667c Mon Sep 17 00:00:00 2001
From: Liam Dannaher
Date: Wed, 28 Aug 2024 13:48:48 -0700
Subject: [PATCH 40/87] Merging Kusto local search into local search

---
 graphrag/query/__main__.py                |  20 +--
 graphrag/query/cli.py                     | 109 +--------------
 graphrag/vector_stores/azure_ai_search.py |  17 +--
 graphrag/vector_stores/base.py            |  17 +--
 graphrag/vector_stores/kusto.py           | 160 +++++-----------------
 graphrag/vector_stores/lancedb.py         |  17 +--
 6 files changed, 49 insertions(+), 291 deletions(-)

diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py
index 875b886ffa..4ee502bd89 100644
--- a/graphrag/query/__main__.py
+++ b/graphrag/query/__main__.py
@@ -6,7 +6,7 @@
 import argparse
 from enum import Enum

-from .cli import run_global_search, run_local_search, run_kusto_local_search, run_kusto_global_search
+from .cli import run_global_search, run_local_search

 INVALID_METHOD_ERROR = "Invalid method"

@@ -118,23 +118,5 @@ def __str__(self):
                 args.context_id,
                 args.query[0],
             )
-        case SearchType.KUSTO_LOCAL:
-            run_kusto_local_search(
-                args.config,
-                args.data,
-                args.root,
-                args.community_level,
-                args.response_type,
-                args.query[0],
-            )
-        case SearchType.KUSTO_GLOBAL:
-            run_kusto_global_search(
-                args.config,
-                args.data,
-                args.root,
-                args.community_level,
-                args.response_type,
-                args.query[0],
-            )
         case _:
             raise ValueError(INVALID_METHOD_ERROR)
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index 9c4c569ac8..d67e79f4ae 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -58,7 +58,7 @@ def __get_embedding_description_store(
     )
     config_args.update({"collection_name": collection_name})
     vector_name = config_args.get(
-        "vector_search_column", "vector"
+        "vector_search_column", "description_embedding"
     )
     config_args.update({"vector_name": vector_name})

@@ -68,7 +68,10 @@ def __get_embedding_description_store(
     description_embedding_store.connect(**config_args)

-    if config_args.get("overwrite", True):
+    if vector_store_type == VectorStoreType.Kusto:
+        description_embedding_store.load_entities(entities)
+
+    elif config_args.get("overwrite", True):
         # this step assumes the embeddings were originally stored in a file rather
         # than a vector database
@@ -188,7 +191,7 @@ def run_local_search(
         final_nodes = pd.concat([final_nodes, read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")])
         final_community_reports = pd.concat([final_community_reports,read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) # KustoDB: Final_entities, Final_Nodes, Final_report should be merged and inserted to kusto
         final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) # lance db search need it for embedding mapping. we have embeddings in entities we should use from there. KustoDB already must have sorted it.
-
+
         if not optimized_search:
             final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")])

@@ -226,7 +229,7 @@ def run_local_search(
             final_community_reports, final_nodes, community_level
         ),
         text_units=read_indexer_text_units(final_text_units),
-        entities=entities,
+        entities=[],
         relationships=read_indexer_relationships(final_relationships),
         covariates={"claims": covariates},
         description_embedding_store=description_embedding_store,
@@ -260,104 +263,6 @@ def read_paraquet_file(storage: PipelineStorage, path: str):
     if file_data is None:
         return pd.DataFrame()
     return pd.read_parquet(BytesIO(file_data), engine="pyarrow")
-
-# TODO I split this out for now to preserve how the original local search worked.
-# I don't think this will necessarily be permanently separate.
-# It was just easier without having to keep everything generic and work the same way as local search worked.
-# One last optimization: Once all the merges are done we can go back to the parquet loads and optimize those for only the fields we need and merge them right away into one big table (I think).
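The "one big table" optimization sketched in the removed comment above is essentially what the Kusto-side `entities` projection (deleted further below with `.set-or-replace`) did server-side. A rough pandas equivalent of that join, assuming the usual graphrag parquet columns (`level`, `title`, `degree`, `community` on nodes and `name` on entities):

```python
import pandas as pd

def build_entities_frame(final_nodes: pd.DataFrame,
                         final_entities: pd.DataFrame,
                         community_level: int = 2) -> pd.DataFrame:
    # create_final_nodes | where level <= N | project name=title, rank=degree, community
    nodes = final_nodes.loc[final_nodes["level"] <= community_level,
                            ["title", "degree", "community"]]
    nodes = nodes.rename(columns={"title": "name", "degree": "rank"})
    # summarize community=max(community) by name, rank
    nodes = nodes.groupby(["name", "rank"], as_index=False)["community"].max()
    # join kind=inner create_final_entities on name
    return nodes.merge(final_entities, on="name", how="inner")
```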
-def run_kusto_local_search( - config_dir: str | None, - data_dir: str | None, - root_dir: str | None, - community_level: int, - response_type: str, - query: str, -): - """Run a local search in Kusto with the given query.""" - data_dir, root_dir, config = _configure_paths_and_settings( - data_dir, root_dir, config_dir - ) - - - vector_store_args = ( - config.embeddings.vector_store if config.embeddings.vector_store else {} - ) - - vector_store_type = vector_store_args.get("type", VectorStoreType.Kusto) - - collection_name = vector_store_args.get( - "query_collection_name", "entities" - ) - vector_store_args.update({"collection_name": collection_name}) - vector_name = vector_store_args.get( - "vector_search_column", "description_embedding" - ) - vector_store_args.update({"vector_name": vector_name}) - - description_embedding_store = VectorStoreFactory.get_vector_store( - vector_store_type=vector_store_type, kwargs=vector_store_args - ) - - description_embedding_store.connect(**vector_store_args) - - description_embedding_store.load_parqs(data_dir, ["create_final_nodes", "create_final_community_reports", "create_final_text_units", "create_final_relationships", "create_final_entities"]) - - gen_parqs = description_embedding_store.read_parqs(data_dir, ["create_final_covariates", "create_final_nodes", "create_final_community_reports", "create_final_text_units", "create_final_relationships", "create_final_entities"]) - dict_parqs = {} - for parq in gen_parqs: - dict_parqs[parq[0]] = parq[1] - final_covariates = dict_parqs.get("create_final_covariates") - final_community_reports = dict_parqs.get("create_final_community_reports") - final_text_units = dict_parqs.get("create_final_text_units") - final_relationships = dict_parqs.get("create_final_relationships") - final_nodes = dict_parqs.get("create_final_nodes") - - create_entities_table(description_embedding_store, community_level) - - covariates = ( - read_indexer_covariates(final_covariates) - if final_covariates is not None - else [] - ) - - reports_result=kt_read_indexer_reports( description_embedding_store, community_level) - - search_engine = get_local_search_engine( - config, - reports=read_indexer_reports( - final_community_reports, final_nodes, community_level - ), - text_units=read_indexer_text_units(final_text_units), - entities=[], - relationships=read_indexer_relationships(final_relationships), - covariates={"claims": covariates}, - description_embedding_store=description_embedding_store, - response_type=response_type, - ) - - result = search_engine.search(query=query) - reporter.success(f"Local Search Response: {result.response}") - return result.response - -# Create entities table similar to read_indexer_entities, but creating that table in Kusto, not in memory. 
-def create_entities_table(description_embedding_store: BaseVectorStore, community_level: int): - description_embedding_store.execute_query(".set-or-replace entities <| ( \ - create_final_nodes | where level <= 2 | project name=['title'] ,rank=degree,community | \ - summarize community=max(community) by name,rank | join kind=inner \ - create_final_entities on name)") - -def run_kusto_global_search( - config_dir: str | None, - data_dir: str | None, - root_dir: str | None, - community_level: int, - response_type: str, - query: str, -): - """Run a global search in Kusto with the given query.""" - raise NotImplementedError("This function is not implemented yet.") - - def _configure_paths_and_settings( data_dir: str | None, diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index 452ad30760..efab79d70f 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -194,20 +194,9 @@ def similarity_search_by_text( ) return [] - def load_parqs(self, data_path, parq_names) -> Any: - raise NotImplementedError("Loading Parquet files is not supported for Azure AI Search") - - def get_extracted_entities( - self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any ) -> list[Entity]: raise NotImplementedError("Extracting entities is not supported for Azure AI Search") - def read_parqs(self, data_dir, parq_names) -> Any: - raise NotImplementedError("Reading Parquet files is not supported for Azure AI Search") - - def get_related_entities(self, titles:list[str], **kwargs: Any) -> list[Entity]: - """Get related entities from the vector store.""" - raise NotImplementedError("Getting related entities is not supported for Azure AI Search") - - def execute_query(self, query: str) -> Any: - return super().execute_query(query) \ No newline at end of file + def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: + raise NotImplementedError("Loading entities is not supported for Azure AI Search") \ No newline at end of file diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index a02fc834da..104c9790c3 100644 --- a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -83,15 +83,6 @@ def similarity_search_by_text( def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: """Build a query filter to filter documents by id.""" - @abstractmethod - def load_parqs(self, data_path: str, parqs: list[str]) -> Any: - """Load documents (Parquet files) into the vector-store.""" - - #TODO This is temporary until I take out the client from the vector store class - @abstractmethod - def execute_query(self, query: str) -> Any: - """Execute a query in the vector-store.""" - @abstractmethod def get_extracted_entities( self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any @@ -99,9 +90,5 @@ def get_extracted_entities( """From a query, build a subtable of entities which is only matching entities.""" @abstractmethod - def read_parqs(self, data_dir, parq_names) -> Any: - """Return a dictionary of parquet dataframes of parq_name to data frame.""" - - @abstractmethod - def get_related_entities(self, titles: list[str], **kwargs: Any) -> list[Entity]: - """Get related entities from the vector store.""" \ No newline at end of file + def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: + """Load entities into the 
vector-store.""" \ No newline at end of file diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 3d135f06b0..3f8e8e9c96 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -35,15 +35,6 @@ class KustoVectorStore(BaseVectorStore): """The Azure Kusto vector storage implementation.""" - #TODO Currently loading in all the parquet fields, need to filter out the ones that are not needed. - #TODO Double check the types. This was done quickly and I may have missed something. - #TODO Check if there is a better way to get the fields to ingest into the Kusto table. These schemas are based off of me reading the files and manually making them. Maybe there is a better way to do this. - schema_dict: typing.ClassVar[dict] = {"create_final_nodes": "(level: int, title: string, type: string, description: string, source_id: string, community: int, degree: int, human_readable_id: int, id: string, size: int, graph_embedding: dynamic, entity_type: string, top_level_node_id: string, x: int, y: int)" - , "create_final_community_reports": "(community: int, full_content: string, level: int, rank: int, title: string, rank_explanation: string, summary: string, findings: string, full_content_json: string, id: string)" - , "create_final_text_units": "(id: string, text: string, n_tokens: int, document_ids: string, entity_ids: string, relationship_ids: string)" - , "create_final_relationships": "(source: string, target: string, weight: real, description: string, text_unit_ids: string, id: string, human_readable_id: string, source_degree: int, target_degree: int, rank: int)" - , "create_final_entities": "(id: string, name: string, type: string, description: string, human_readable_id: int, graph_embedding: dynamic, text_unit_ids: string, description_embedding: dynamic)"} - def connect(self, **kwargs: Any) -> Any: """ Connect to the vector storage. 
@@ -96,7 +87,6 @@ def load_documents( return # Convert data to DataFrame - import pandas as pd df = pd.DataFrame(data) # Create or replace table @@ -157,29 +147,15 @@ def similarity_search_by_vector( print("Similarities of the search results:", [row["similarity"] for _, row in df.iterrows()]) # Temporary to support the original entity_description_embedding - if(self.vector_name == "vector"): - return [ - VectorStoreSearchResult( - document=VectorStoreDocument( - id=row["id"], - text=row["text"], - vector=row[self.vector_name], - attributes=row["attributes"], - ), - score= 1 + float(row["similarity"]), - ) - for _, row in df.iterrows() - ] - return [ VectorStoreSearchResult( document=VectorStoreDocument( id=row["id"], - text=row["name"], + text=row["text"], vector=row[self.vector_name], - attributes={"title":row["name"]}, + attributes=row["attributes"], ), - score= 1 + float(row["similarity"]), # get a [0,2] range; work with positvie numbers + score= 1 + float(row["similarity"]), # 1 + similarity to make it a score between 0 and 2 ) for _, row in df.iterrows() ] @@ -204,120 +180,50 @@ def similarity_search_by_text( return self.similarity_search_by_vector(query_embedding, k) return [] - - def execute_query(self, query: str) -> Any: - return self.client.execute(self.database, f"{query}") - def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any ) -> list[Entity]: - query_results = self.similarity_search_by_text(text, text_embedder, k) - query_ids = [result.document.id for result in query_results] - if query_ids not in [[], None]: - ids_str = "\", \"".join([str(id) for id in query_ids]) - query = f""" - entities - | where id in ("{ids_str}") - """ - print(query) - response = self.client.execute(self.database, query) - df = dataframe_from_result_table(response.primary_results[0]) - return self.__extract_entities_from_data_frame(df) - return [] + query_embedding = text_embedder(text) + query = f""" + let query_vector = dynamic({query_embedding}); + {self.collection_name} + | extend similarity = series_cosine_similarity(query_vector, {self.vector_name}) + | top {k} by similarity desc + """ + response = self.client.execute(self.database, query) + df = dataframe_from_result_table(response.primary_results[0]) - def __extract_entities_from_data_frame(self, df: pd.DataFrame) -> list[Entity]: return [ Entity( id=row["id"], - title=row["name1"], + title=row["title"], type=row["type"], description=row["description"], graph_embedding=row["graph_embedding"], text_unit_ids=row["text_unit_ids"], description_embedding=row["description_embedding"], short_id="", - community_ids=[row["community"]], + community_ids=row["community_ids"], + document_ids=row["document_ids"], rank=row["rank"], - attributes={"title":row["name1"]}, - ) - for _, row in df.iterrows() + attributes=row["attributes"], + ) for _, row in df.iterrows() ] - def get_related_entities(self, titles: list[str], **kwargs: Any) -> list[Entity]: - """Get related entities based on the given titles.""" - titles_str = "\", \"".join(titles) + def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: + # Convert data to DataFrame + df = pd.DataFrame(entities) - query = f""" - create_final_relationships - | where source in ("{titles_str}") - | project name=target - | join kind=inner create_final_entities on name - """ - response = self.client.execute(self.database, query) - df = dataframe_from_result_table(response.primary_results[0]) - selected_entities = 
self.__extract_entities_from_data_frame(df) + # Create or replace table + if overwrite: + command = f".drop table {self.collection_name} ifexists" + self.client.execute(self.database, command) + command = f".create table {self.collection_name} (id: string, short_id: real, title: string, type: string, description: string, description_embedding: dynamic, name_embedding: dynamic, graph_embedding: dynamic, community_ids: dynamic, text_unit_ids: dynamic, document_ids: dynamic, rank: real, attributes: dynamic)" + self.client.execute(self.database, command) + command = f".alter column {self.collection_name}.graph_embedding policy encoding type = 'Vector16'" + self.client.execute(self.database, command) + command = f".alter column {self.collection_name}.description_embedding policy encoding type = 'Vector16'" + self.client.execute(self.database, command) - query = f""" - create_final_relationships - | where target in ("{titles_str}") - | project name=source - | join kind=inner create_final_entities on name - """ - response = self.client.execute(self.database, query) - df = dataframe_from_result_table(response.primary_results[0]) - selected_entities += self.__extract_entities_from_data_frame(df) - - return selected_entities - - def load_parqs(self, data_dir, parq_names) -> Any: - data_path = Path(data_dir) - for parq_name in parq_names: - parq_path = data_path / f"{parq_name}.parquet" - if parq_path.exists(): - parq = pd.read_parquet(parq_path) - - # I wasn't sure if was easier to rename the columns here or in the KQL queries. - # Most likely the KQL queries as this is a place I am trying to handle all the parquet files generically. - # parq.rename(columns={"id": "title"}, inplace=True) - # parq = cast(pd.DataFrame, parq[["title", "degree", "community"]]).rename( - # columns={"title": "name", "degree": "rank"} - # ) - - command = f".drop table {parq_name} ifexists" - self.client.execute(self.database, command) - command = f".create table {parq_name} {self.schema_dict[parq_name]}" - self.client.execute(self.database, command) - - # Due to an issue with to_csv not being able to handle float64, I had to manually handle entities. 
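The float64 note in the removed comment above is the crux of the Kusto loader: numpy embedding arrays do not survive `DataFrame.to_csv`, so the loader rebuilt each row with embeddings serialized as JSON strings before inline ingestion. A trimmed sketch of that workaround (column set reduced for brevity, plain conversions standing in for the `to_str`/`to_optional_list` helpers):

```python
import json
import pandas as pd

def entities_to_inline_csv(parq: pd.DataFrame) -> str:
    rows = [
        {
            "id": str(row["id"]),
            "name": str(row["name"]),
            # numpy float64 arrays become JSON text so to_csv emits one
            # well-formed cell per embedding
            "description_embedding": json.dumps(
                list(row["description_embedding"])
                if row["description_embedding"] is not None else []
            ),
        }
        for _, row in parq.iterrows()
    ]
    return pd.DataFrame(rows).to_csv(index=False, header=False)

# ingestion then becomes:
# client.execute(database,
#     ".ingest inline into table create_final_entities <| " + entities_to_inline_csv(parq))
```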
- if parq_name == "create_final_entities": - command = f".alter column create_final_entities.graph_embedding policy encoding type = 'Vector16'" - self.client.execute(self.database, command) - command = f".alter column create_final_entities.description_embedding policy encoding type = 'Vector16'" - self.client.execute(self.database, command) - data = [ - { - "id": to_str(row, "id"), - "name": to_str(row, "name"), - "type": to_optional_str(row, "type"), - "description": to_optional_str(row, "description"), - "human_readable_id": to_optional_str(row, "human_readable_id"), - "graph_embedding": to_optional_list(row, "graph_embedding"), - "text_unit_ids": to_optional_list(row, "text_unit_ids"), - "description_embedding": to_optional_list(row, "description_embedding"), - } - for idx, row in parq.iterrows() - ] - parq = pd.DataFrame(data) - command = f".ingest inline into table {parq_name} <| {parq.to_csv(index=False, header=False)}" - self.client.execute(self.database, command) - else: - print(f"Parquet file {parq_path} not found.") - - def read_parqs(self, data_dir, parq_names) -> Any: - """Return a dictionary of parquet dataframes of parq_name to data frame.""" - data_path = Path(data_dir) - for parq_name in parq_names: - parq_path = data_path / f"{parq_name}.parquet" - parq = None - if parq_path.exists(): - parq = pd.read_parquet(parq_path) - yield parq_name, parq + # Ingest data + ingestion_command = f".ingest inline into table {self.collection_name} <| {df.to_csv(index=False, header=False)}" + self.client.execute(self.database, ingestion_command) diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index 021312500c..095b6bae42 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ -121,20 +121,9 @@ def similarity_search_by_text( return self.similarity_search_by_vector(query_embedding, k) return [] - def load_parqs(self, data_path, parq_names) -> Any: - raise NotImplementedError("Loading Parquet files is not supported for LanceDB") - - def get_extracted_entities( - self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any ) -> list[Entity]: raise NotImplementedError("Extracting entities is not supported for LanceDB") - def read_parqs(self, data_dir, parq_names) -> Any: - raise NotImplementedError("Reading Parquet files is not supported for LanceDB") - - def get_related_entities(self, titles: list[str], **kwargs: Any) -> list[Entity]: - """Get related entities from the vector store.""" - raise NotImplementedError("Getting related entities is not supported for LanceDB") - - def execute_query(self, query: str) -> Any: - return super().execute_query(query) \ No newline at end of file + def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: + raise NotImplementedError("Loading entities is not supported for LanceDB") From 0039b0d1f7b55530a3cc6f06d8a101b547cea45f Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Wed, 28 Aug 2024 15:23:44 -0700 Subject: [PATCH 41/87] Fixing lancedb from Sirus suggestion --- graphrag/query/cli.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index d67e79f4ae..c8dccf3564 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -223,13 +223,16 @@ def run_local_search( else [] ) + if(isinstance(description_embedding_store, KustoVectorStore)): + entities = [] + search_engine = 
get_local_search_engine( config, reports=read_indexer_reports( final_community_reports, final_nodes, community_level ), text_units=read_indexer_text_units(final_text_units), - entities=[], + entities=entities, relationships=read_indexer_relationships(final_relationships), covariates={"claims": covariates}, description_embedding_store=description_embedding_store, From 16a5766544f1cdc1750ccec75f97648b57b75b4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Thu, 29 Aug 2024 16:10:23 -0700 Subject: [PATCH 42/87] Add functionality for context graph creation --- common/graph_db_client.py | 21 ++----------------- graphrag/index/cli.py | 3 ++- .../index/context_switch/contextSwitcher.py | 15 ++++++++++++- graphrag/index/emit/factories.py | 7 ++++--- graphrag/index/emit/graph_db_emitter.py | 4 ++-- graphrag/index/run.py | 6 +++++- graphrag/query/cli.py | 2 +- poetry.lock | 17 ++++++++++++++- pyproject.toml | 1 + 9 files changed, 47 insertions(+), 29 deletions(-) diff --git a/common/graph_db_client.py b/common/graph_db_client.py index 483c49d2f6..4c83818585 100644 --- a/common/graph_db_client.py +++ b/common/graph_db_client.py @@ -13,13 +13,12 @@ import json class GraphDBClient: - def __init__(self,graph_db_params: GraphDBConfig|None): + def __init__(self,graph_db_params: GraphDBConfig|None,context_id: str|None): self.username_prefix=graph_db_params.username - self.current_context="00000000-0000-0000-0000-000000000000" self._client=client.Client( url=f"wss://{graph_db_params.account_name}.gremlin.cosmos.azure.com:443/", traversal_source="g", - username=self.username_prefix+"-contextid-"+self.current_context, + username=self.username_prefix+"-contextid-"+context_id, password=f"{graph_db_params.account_key}", message_serializer=serializer.GraphSONSerializersV2d0(), ) @@ -77,22 +76,7 @@ def element_exists(self,element_type:str,element_id:int,conditions:str="")->bool element_count=counts[0] return element_count>0 - def switch_graphs(self,context_id:str)->None: - if context_id==self.current_context: - return - self.current_context=context_id - updated_client=client.Client( - url=self._client._url, - traversal_source="g", - username=self.username_prefix+"-contextid-"+self.current_context, - password=self._client._password, - message_serializer=serializer.GraphSONSerializersV2d0(), - ) - self._client.close() - self._client=updated_client - def write_vertices(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None: - self.switch_graphs(context_id) for row in data.itertuples(): if self.element_exists("g.V()",row.id): continue @@ -126,7 +110,6 @@ def write_vertices(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-00 def write_edges(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None: - self.switch_graphs(context_id) for row in data.itertuples(): if self.element_exists("g.E()",row.id): continue diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index cc557a890a..f05a409fe6 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -108,7 +108,7 @@ def index_cli( ValueError("ContextOperation is invalid: It should be Active or DeActive") graphrag_config = _read_config_parameters(root, config, progress_reporter) _switch_context(graphrag_config, context_operation, context_id, progress_reporter) - sys.exit(0) + #sys.exit(0) cache = NoopPipelineCache() if nocache else None pipeline_emit = emit.split(",") if emit else None encountered_errors = False @@ -144,6 +144,7 @@ async def 
execute(): else None ), is_resume_run=bool(resume), + context_id=context_id, ): if output.errors and len(output.errors) > 0: encountered_errors = True diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index 7ff1fa023f..6c1e242cc1 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -1,5 +1,6 @@ from graphrag.index.progress import ProgressReporter from graphrag.config import GraphRagConfig +from azure.cosmos import CosmosClient, PartitionKey class ContextSwitcher: """ContextSwitcher class definition.""" @@ -14,7 +15,19 @@ def activate(self, config: GraphRagConfig | str, contextId: str | None, reporter #2. read the file from storage using common/blob_storage_client.py #3. GraphDB: use cosmos db client to load data into Cosmos DB. #4. KustoDB: use Kusto client to load embedding data into Kusto. - + if config.graphdb.enabled: + cosmos_client = CosmosClient( + f"https://{config.graphdb.account_name}.documents.azure.com:443/", + f"{config.graphdb.account_key}", + ) + database_name = config.graphdb.username.split("/")[2] + database = cosmos_client.get_database_client(database_name) + graph_name=config.graphdb.username.split("/")[-1]+"-contextid-"+contextId + graph = database.create_container_if_not_exists( + id=graph_name, + partition_key=PartitionKey(path='/category'), + offer_throughput=400 + ) return 0 def deactivate(self, config: GraphRagConfig | str, contextId: str | None, reporter: ProgressReporter): diff --git a/graphrag/index/emit/factories.py b/graphrag/index/emit/factories.py index b9417787af..2a7cbc0f8f 100644 --- a/graphrag/index/emit/factories.py +++ b/graphrag/index/emit/factories.py @@ -15,7 +15,7 @@ from .types import TableEmitterType def create_table_emitter( - emitter_type: TableEmitterType, storage: PipelineStorage, on_error: ErrorHandlerFn, graphdb_params: GraphDBConfig|None = None + emitter_type: TableEmitterType, storage: PipelineStorage, on_error: ErrorHandlerFn, graphdb_params: GraphDBConfig|None = None, context_id: str|None = None ) -> TableEmitter: """Create a table emitter based on the specified type.""" match emitter_type: @@ -26,7 +26,7 @@ def create_table_emitter( case TableEmitterType.CSV: return CSVTableEmitter(storage) case TableEmitterType.Graphdb: - return GraphDBEmitter(graphdb_params) + return GraphDBEmitter(graphdb_params,context_id) case _: msg = f"Unsupported table emitter type: {emitter_type}" raise ValueError(msg) @@ -37,9 +37,10 @@ def create_table_emitters( storage: PipelineStorage, on_error: ErrorHandlerFn, graphdb_params: GraphDBConfig|None = None, + context_id: str|None = None, ) -> list[TableEmitter]: """Create a list of table emitters based on the specified types.""" return [ - create_table_emitter(emitter_type, storage, on_error, graphdb_params) + create_table_emitter(emitter_type, storage, on_error, graphdb_params,context_id) for emitter_type in emitter_types ] diff --git a/graphrag/index/emit/graph_db_emitter.py b/graphrag/index/emit/graph_db_emitter.py index 5365835e77..f0f7ce95ba 100644 --- a/graphrag/index/emit/graph_db_emitter.py +++ b/graphrag/index/emit/graph_db_emitter.py @@ -19,8 +19,8 @@ class GraphDBEmitter(TableEmitter): - def __init__(self, graph_db_params: GraphDBConfig|None): - self.graph_db_client = GraphDBClient(graph_db_params) + def __init__(self, graph_db_params: GraphDBConfig|None,context_id: str|None): + self.graph_db_client = GraphDBClient(graph_db_params,context_id) self.allowed_workflows = 
['create_final_entities','create_final_relationships'] async def emit(self, name: str, data: pd.DataFrame) -> None: diff --git a/graphrag/index/run.py b/graphrag/index/run.py index 7ea6faea2b..7b3be7738b 100644 --- a/graphrag/index/run.py +++ b/graphrag/index/run.py @@ -83,6 +83,7 @@ async def run_pipeline_with_config( memory_profile: bool = False, run_id: str | None = None, is_resume_run: bool = False, + context_id: str | None = None, **_kwargs: dict, ) -> AsyncIterable[PipelineRunResult]: """Run a pipeline with the given config. @@ -166,6 +167,7 @@ def _create_postprocess_steps( emit=emit, is_resume_run=is_resume_run, graphdb_params = config.graphdb_params, + context_id=context_id, ): yield table @@ -184,6 +186,7 @@ async def run_pipeline( memory_profile: bool = False, is_resume_run: bool = False, graphdb_params: GraphDBConfig|None = None, + context_id: str | None = None, **_kwargs: dict, ) -> AsyncIterable[PipelineRunResult]: """Run the pipeline. @@ -219,7 +222,8 @@ async def run_pipeline( lambda e, s, d: cast(WorkflowCallbacks, callbacks).on_error( "Error emitting table", e, s, d ), - graphdb_params + graphdb_params, + context_id, ) loaded_workflows = load_workflows( workflows, diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 609de371f8..27ac3cc4f8 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -169,7 +169,7 @@ def run_local_search( final_entities = pd.DataFrame() final_covariates = pd.DataFrame() if config.graphdb.enabled: - graph_db_client = GraphDBClient(config.graphdb) + graph_db_client = GraphDBClient(config.graphdb,context_id) for data_path in data_paths: #check from the config for the ouptut storage type and then read the data from the storage. diff --git a/poetry.lock b/poetry.lock index fbacd85265..ce5edf4733 100644 --- a/poetry.lock +++ b/poetry.lock @@ -412,6 +412,21 @@ typing-extensions = ">=4.6.0" [package.extras] aio = ["aiohttp (>=3.0)"] +[[package]] +name = "azure-cosmos" +version = "4.7.0" +description = "Microsoft Azure Cosmos Client Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure-cosmos-4.7.0.tar.gz", hash = "sha256:72d714033134656302a2e8957c4b93590673bd288b0ca60cb123e348ae99a241"}, + {file = "azure_cosmos-4.7.0-py3-none-any.whl", hash = "sha256:03d8c7740ddc2906fb16e07b136acc0fe6a6a02656db46c5dd6f1b127b58cc96"}, +] + +[package.dependencies] +azure-core = ">=1.25.1" +typing-extensions = ">=4.6.0" + [[package]] name = "azure-identity" version = "1.17.1" @@ -5858,4 +5873,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "f06dbe201e3dfea982b0b052a3d6811e1be7acde113f3276f61390bd80684447" +content-hash = "3486484392ad6f778df36fc15643b7a0ffa7495893ba15c66ad467f946dca692" diff --git a/pyproject.toml b/pyproject.toml index 6978359eb2..4ae07d7552 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ azure-storage-blob = "^12.19.0" azure-identity = "^1.17.1" json-repair = "^0.25.3" gremlinpython = "^3.7.2" +azure-cosmos = "^4.7.0" [tool.poetry.group.dev.dependencies] coverage = "^7.6.0" From b1d4f221d6cae01333b58aedb214f9d04aec6256 Mon Sep 17 00:00:00 2001 From: sirus-ms Date: Fri, 30 Aug 2024 09:49:51 -0700 Subject: [PATCH 43/87] Kusto context-switch * Kusto context-switch --- graphrag/index/__main__.py | 7 + graphrag/index/cli.py | 14 +- .../index/context_switch/contextSwitcher.py | 239 +++++++++++++++++- graphrag/query/__main__.py | 2 +- graphrag/query/cli.py | 8 +- 
graphrag/vector_stores/kusto.py | 3 + 6 files changed, 257 insertions(+), 16 deletions(-) diff --git a/graphrag/index/__main__.py b/graphrag/index/__main__.py index 7bdac87734..9a652ae07a 100644 --- a/graphrag/index/__main__.py +++ b/graphrag/index/__main__.py @@ -81,6 +81,12 @@ help="Overlay default configuration values on a provided configuration file (--config).", action="store_true", ) + parser.add_argument( + "--community_level", + help="Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities", + type=int, + default=2, + ) args = parser.parse_args() @@ -102,4 +108,5 @@ cli=True, context_id=args.context_id, context_operation=args.context_operation, + community_level=args.community_level ) diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index 0a370f42a7..67b9f5a804 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -73,6 +73,7 @@ def redact_dict(input: dict) -> dict: def index_cli( root: str, init: bool, + community_level: int, context_operation: str | None, context_id: str | None, verbose: bool, @@ -106,8 +107,8 @@ def index_cli( ValueError("ContextId is invalid: It should be a valid Guid") if (context_operation != ContextSwitchType.Activate and context_operation != ContextSwitchType.Deactivate): ValueError("ContextOperation is invalid: It should be Active or DeActive") - graphrag_config = _read_config_parameters(root, config, progress_reporter) - _switch_context(graphrag_config, context_operation, context_id, progress_reporter) + #graphrag_config = _read_config_parameters(root, config, progress_reporter) + _switch_context(config,root,context_operation,context_id,progress_reporter,community_level) sys.exit(0) cache = NoopPipelineCache() if nocache else None pipeline_emit = emit.split(",") if emit else None @@ -182,15 +183,16 @@ async def execute(): if cli: sys.exit(1 if encountered_errors else 0) -def _switch_context(config: GraphRagConfig | str, context_operation: str | None, context_id: str, reporter: ProgressReporter) -> None: +def _switch_context(config: GraphRagConfig | str, root: str , context_operation: str | None, + context_id: str, reporter: ProgressReporter,community_level: int) -> None: """Switch the context to the given context.""" reporter.info(f"Switching context to {context_id} using operation {context_operation}") from graphrag.index.context_switch.contextSwitcher import ContextSwitcher - context_switcher = ContextSwitcher() + context_switcher = ContextSwitcher(root,config,reporter,context_id,community_level) if context_operation == ContextSwitchType.Activate: - context_switcher.activate(config, context_id, reporter) + context_switcher.activate() elif context_operation == ContextSwitchType.Deactivate: - context_switcher.deactivate(config, context_id, reporter) + context_switcher.deactivate() else: msg = f"Invalid context operation {context_operation}" raise ValueError(msg) diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index f992eb6354..aa3d06d73d 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -1,22 +1,247 @@ from graphrag.common.progress import ProgressReporter from graphrag.config import GraphRagConfig +from graphrag.config.enums import StorageType +from graphrag.common.storage import PipelineStorage, BlobPipelineStorage, FilePipelineStorage +from graphrag.common.utils.context_utils import 
get_files_by_contextid +import pandas as pd +from typing import cast +from azure.core.exceptions import ResourceNotFoundError +import asyncio +from io import BytesIO +from pathlib import Path +from graphrag.config import ( + create_graphrag_config, + GraphRagConfig, +) +from common.graph_db_client import GraphDBClient +import os +from graphrag.vector_stores import VectorStoreFactory, VectorStoreType +from graphrag.vector_stores.base import BaseVectorStore +from graphrag.vector_stores.lancedb import LanceDBVectorStore +from graphrag.vector_stores.kusto import KustoVectorStore +from graphrag.query.indexer_adapters import ( + read_indexer_covariates, + read_indexer_entities, + read_indexer_relationships, + read_indexer_reports, + kt_read_indexer_reports, + read_indexer_text_units, +) +from graphrag.model.entity import Entity class ContextSwitcher: """ContextSwitcher class definition.""" - def __init__(self): - #initialize Gremline and Cosmos Db client here. - pass - def activate(self, config: GraphRagConfig | str, contextId: str | None, reporter: ProgressReporter): + def __init__(self, root_dir:str , config_dir:str,reporter: ProgressReporter, + context_id:str, community_level:int , + data_dir: str = None, + optimized_search: bool= False): + + self.root_dir=root_dir + self.config_dir=config_dir + self.data_dir=data_dir + self.reporter=reporter + self.context_id=context_id + self.optimized_search=optimized_search + self.community_level = community_level + + def set_ctx_activation( + self, + activate: int, + entities: list[Entity]=[], + config_args: dict | None = None, + ): + if not config_args: + config_args = {} + + collection_name = config_args.get( + "query_collection_name", "entity_description_embeddings" + ) + + collection_name += "_" + self.context_id + config_args.update({"collection_name": collection_name}) + vector_name = config_args.get( + "vector_search_column", "description_embedding" + ) + config_args.update({"vector_name": vector_name}) + + description_embedding_store = VectorStoreFactory.get_vector_store( + vector_store_type=VectorStoreType.Kusto, kwargs=config_args + ) + description_embedding_store.connect(**config_args) + + if activate: + description_embedding_store.load_entities(entities) + else: + description_embedding_store.unload_entities() + + return 0 + + def activate(self): """Activate the context.""" #1. read the context id to fileId mapping. #2. read the file from storage using common/blob_storage_client.py #3. GraphDB: use cosmos db client to load data into Cosmos DB. #4. KustoDB: use Kusto client to load embedding data into Kusto. - return 0 + data_dir=self.data_dir + root_dir=self.root_dir + config_dir=self.config_dir + reporter=self.reporter + context_id=self.context_id + optimized_search=self.optimized_search + community_level=self.community_level + + def read_paraquet_file(storage: PipelineStorage, path: str): + #create different enum for paraquet storage type + file_data = asyncio.run(storage.get(path, True)) + if file_data is None: + return pd.DataFrame() + return pd.read_parquet(BytesIO(file_data), engine="pyarrow") + + def _configure_paths_and_settings( + data_dir: str | None, + root_dir: str | None, + config_dir: str | None, + ) -> tuple[str, str | None, GraphRagConfig]: + if data_dir is None and root_dir is None: + msg = "Either data_dir or root_dir must be provided." 
+            raise ValueError(msg)
+        if data_dir is None:
+            data_dir = _infer_data_dir(cast(str, root_dir))
+        config = _create_graphrag_config(root_dir, config_dir)
+        return data_dir, root_dir, config
+
+
+        def _infer_data_dir(root: str) -> str:
+            output = Path(root) / "output"
+            # use the latest data-run folder
+            if output.exists():
+                folders = sorted(output.iterdir(), key=os.path.getmtime, reverse=True)
+                if len(folders) > 0:
+                    folder = folders[0]
+                    return str((folder / "artifacts").absolute())
+            msg = f"Could not infer data directory from root={root}"
+            raise ValueError(msg)
+
+
+        def _create_graphrag_config(
+            root: str | None,
+            config_dir: str | None,
+        ) -> GraphRagConfig:
+            """Create a GraphRag configuration."""
+            return _read_config_parameters(root or "./", config_dir)
+
+
+        def _read_config_parameters(root: str, config: str | None):
+            _root = Path(root)
+            settings_yaml = (
+                Path(config)
+                if config and Path(config).suffix in [".yaml", ".yml"]
+                else _root / "settings.yaml"
+            )
+            if not settings_yaml.exists():
+                settings_yaml = _root / "settings.yml"
+
+            if settings_yaml.exists():
+                reporter.info(f"Reading settings from {settings_yaml}")
+                with settings_yaml.open(
+                    "rb",
+                ) as file:
+                    import yaml
+
+                    data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
+                    return create_graphrag_config(data, root)
+
+            settings_json = (
+                Path(config)
+                if config and Path(config).suffix == ".json"
+                else _root / "settings.json"
+            )
+            if settings_json.exists():
+                reporter.info(f"Reading settings from {settings_json}")
+                with settings_json.open("rb") as file:
+                    import json
+
+                    data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
+                    return create_graphrag_config(data, root)
+
+            reporter.info("Reading settings from environment variables")
+            return create_graphrag_config(root_dir=root)
+
+
+
+        ################################################################################
+
+
+        _, _, config = _configure_paths_and_settings(
+            data_dir, root_dir, config_dir
+        )
+
+        if(config.storage.type == StorageType.memory):
+            raise ValueError("Memory storage is not supported")
+        if(config.storage.type == StorageType.blob):
+            if(config.storage.container_name is not None):
+                input_storage_client: PipelineStorage = BlobPipelineStorage(config.storage.connection_string, config.storage.container_name)
+                output_storage_client: PipelineStorage = BlobPipelineStorage(config.storage.connection_string, config.storage.container_name)
+            else:
+                raise ValueError("Storage type is Blob but container name is invalid")
+        if(config.storage.type == StorageType.file):
+            input_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir)
+            output_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir)
+
+        data_paths = []
+        data_paths = get_files_by_contextid(config, context_id)
+        final_nodes = pd.DataFrame()
+        final_community_reports = pd.DataFrame()
+        final_text_units = pd.DataFrame()
+        final_relationships = pd.DataFrame()
+        final_entities = pd.DataFrame()
+        final_covariates = pd.DataFrame()
+        if config.graphdb.enabled:
+            graph_db_client = GraphDBClient(config.graphdb)
+        for data_path in data_paths:
+            #check from the config for the output storage type and then read the data from the storage.
+
+            #GraphDB: we may need to make change below to read nodes data from Graph DB
+            final_nodes = pd.concat([final_nodes, read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")])
+            final_community_reports = pd.concat([final_community_reports,read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) # KustoDB: Final_entities, Final_Nodes, Final_report should be merged and inserted to kusto
+            final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) # lance db search need it for embedding mapping. we have embeddings in entities we should use from there. KustoDB already must have sorted it.
+
+            if not optimized_search:
+                final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")])
+
+            if config.graphdb.enabled:
+                final_relationships = pd.concat([final_relationships, graph_db_client.query_edges()])
+                final_entities = pd.concat([final_entities, graph_db_client.query_vertices()])
+            else:
+                final_relationships = pd.concat([final_relationships, read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")])
+                final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")])
+
+
+        vector_store_args = (
+            config.embeddings.vector_store if config.embeddings.vector_store else {}
+        )
+
+        reporter.info(f"Vector Store Args: {vector_store_args}")
+
+        if "type" not in vector_store_args:
+            raise ValueError("vector_store.type can't be empty")
+
+        vector_store_type = vector_store_args.get("type")
+
+        if vector_store_type != VectorStoreType.Kusto:
+            raise ValueError("Context switching is only supported for vector_store.type=kusto")
+
+        entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it.
+
+        self.set_ctx_activation(
+            entities=entities,
+            activate=1, config_args=vector_store_args,
+        )

-    def deactivate(self, config: GraphRagConfig | str, contextId: str | None, reporter: ProgressReporter):
+    def deactivate(self):
         """DeActivate the context."""
         #1. Delete all the data for a given context id.
-        return 0
\ No newline at end of file
+        self.set_ctx_activation(0)
\ No newline at end of file
diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py
index 4ee502bd89..1d1cd25b4f 100644
--- a/graphrag/query/__main__.py
+++ b/graphrag/query/__main__.py
@@ -75,7 +75,7 @@ def __str__(self):
         "--context_id",
         help="Guid describing context in which the search should be performed",
         type=str,
-        default="00000000-0000-0000-0000-000000000000",
+        #default="00000000-0000-0000-0000-000000000000",
     )

     parser.add_argument(
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index c8dccf3564..689a942705 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -69,7 +69,7 @@ def __get_embedding_description_store(
     description_embedding_store.connect(**config_args)

     if vector_store_type == VectorStoreType.Kusto:
-        description_embedding_store.load_entities(entities)
+        return description_embedding_store

     elif config_args.get("overwrite", True):
         # this step assumes the embeddings were originally stored in a file rather
         # than a vector database
@@ -161,6 +161,8 @@ def run_local_search(
         data_dir, root_dir, config_dir
     )

+    # TODO: loading stage here must be only limited to default lancedb.
+ # for the POC purpose input artifacts blob, output artifacts blob and input query blob storage are going to same. if(config.storage.type == StorageType.memory): ValueError("Memory storage is not supported") @@ -208,9 +210,11 @@ def run_local_search( ) reporter.info(f"Vector Store Args: {vector_store_args}") - vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) # verify kusto vector store here. + vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it. + + description_embedding_store = __get_embedding_description_store( entities=entities, vector_store_type=vector_store_type, diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 3f8e8e9c96..e12c872239 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -208,6 +208,9 @@ def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int attributes=row["attributes"], ) for _, row in df.iterrows() ] + + def unload_entities(self) -> None: + self.client.execute(self.database,f".drop table {self.collection_name} ifexists") def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: # Convert data to DataFrame From 05e721c5103fbb51c6487401b4ac8c6c73dfe0bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Mon, 26 Aug 2024 15:25:33 -0700 Subject: [PATCH 44/87] Include context into read and write calls for graphdb --- common/graph_db_client.py | 98 +++++++++++++++++++++++++++------------ graphrag/query/cli.py | 4 +- 2 files changed, 70 insertions(+), 32 deletions(-) diff --git a/common/graph_db_client.py b/common/graph_db_client.py index db1d467df2..483c49d2f6 100644 --- a/common/graph_db_client.py +++ b/common/graph_db_client.py @@ -14,10 +14,12 @@ class GraphDBClient: def __init__(self,graph_db_params: GraphDBConfig|None): + self.username_prefix=graph_db_params.username + self.current_context="00000000-0000-0000-0000-000000000000" self._client=client.Client( url=f"wss://{graph_db_params.account_name}.gremlin.cosmos.azure.com:443/", traversal_source="g", - username=graph_db_params.username, + username=self.username_prefix+"-contextid-"+self.current_context, password=f"{graph_db_params.account_key}", message_serializer=serializer.GraphSONSerializersV2d0(), ) @@ -42,7 +44,7 @@ def result_to_df(self,result) -> pd.DataFrame: df = pd.DataFrame(json_data) return df - def query_vertices(self) -> pd.DataFrame: + def query_vertices(self,context_id:str="00000000-0000-0000-0000-000000000000") -> pd.DataFrame: result = self._client.submit( message=( "g.V()" @@ -50,7 +52,7 @@ def query_vertices(self) -> pd.DataFrame: ) return self.result_to_df(result) - def query_edges(self) -> pd.DataFrame: + def query_edges(self,context_id:str="00000000-0000-0000-0000-000000000000") -> pd.DataFrame: result = self._client.submit( message=( "g.E()" @@ -58,40 +60,76 @@ def query_edges(self) -> pd.DataFrame: ) return self.result_to_df(result) - def write_vertices(self,data: pd.DataFrame)->None: - for row in data.itertuples(): - print(row.id) - self._client.submit( + def element_exists(self,element_type:str,element_id:int,conditions:str="")->bool: + result=self._client.submit( message=( - "g.addV('entity')" - ".property('id', prop_id)" - ".property('name', prop_name)" - ".property('type', prop_type)" - ".property('description','prop_description')" - 
".property('human_readable_id', prop_human_readable_id)" - ".property('category', prop_partition_key)" - ".property(list,'description_embedding',prop_description_embedding)" - ".property(list,'graph_embedding',prop_graph_embedding)" - ".property(list,'text_unit_ids',prop_text_unit_ids)" + element_type+ + ".has('id',prop_id)"+ + conditions+ + ".count()" ), bindings={ - "prop_id": row.id, - "prop_name": row.name, - "prop_type": row.type, - "prop_description": row.description, - "prop_human_readable_id": row.human_readable_id, - "prop_partition_key": "entities", - "prop_description_embedding":json.dumps(row.description_embedding.tolist() if row.description_embedding is not None else []), - "prop_graph_embedding":json.dumps(row.graph_embedding.tolist() if row.graph_embedding is not None else []), - "prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []), - }, - ) + "prop_id":element_id, + } + ) + element_count=0 + for counts in result: + element_count=counts[0] + return element_count>0 + + def switch_graphs(self,context_id:str)->None: + if context_id==self.current_context: + return + self.current_context=context_id + updated_client=client.Client( + url=self._client._url, + traversal_source="g", + username=self.username_prefix+"-contextid-"+self.current_context, + password=self._client._password, + message_serializer=serializer.GraphSONSerializersV2d0(), + ) + self._client.close() + self._client=updated_client + + def write_vertices(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None: + self.switch_graphs(context_id) + for row in data.itertuples(): + if self.element_exists("g.V()",row.id): + continue + else: + self._client.submit( + message=( + "g.addV('entity')" + ".property('id', prop_id)" + ".property('name', prop_name)" + ".property('type', prop_type)" + ".property('description','prop_description')" + ".property('human_readable_id', prop_human_readable_id)" + ".property('category', prop_partition_key)" + ".property(list,'description_embedding',prop_description_embedding)" + ".property(list,'graph_embedding',prop_graph_embedding)" + ".property(list,'text_unit_ids',prop_text_unit_ids)" + ), + bindings={ + "prop_id": row.id, + "prop_name": row.name, + "prop_type": row.type, + "prop_description": row.description, + "prop_human_readable_id": row.human_readable_id, + "prop_partition_key": "entities", + "prop_description_embedding":json.dumps(row.description_embedding.tolist() if row.description_embedding is not None else []), + "prop_graph_embedding":json.dumps(row.graph_embedding.tolist() if row.graph_embedding is not None else []), + "prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []), + }, + ) time.sleep(5) - def write_edges(self,data: pd.DataFrame)->None: + def write_edges(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None: + self.switch_graphs(context_id) for row in data.itertuples(): - print(row.source,row.target) + if self.element_exists("g.E()",row.id): + continue self._client.submit( message=( "g.V().has('name',prop_source_id)" diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 689a942705..3457029aa9 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -198,8 +198,8 @@ def run_local_search( final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) if config.graphdb.enabled: - final_relationships = pd.concat([final_relationships, 
graph_db_client.query_edges()]) - final_entities = pd.concat([final_entities, graph_db_client.query_vertices()]) + final_relationships = pd.concat([final_relationships, graph_db_client.query_edges(context_id)]) + final_entities = pd.concat([final_entities, graph_db_client.query_vertices(context_id)]) else: final_relationships = pd.concat([final_relationships, read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")]) final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")]) From 689b9832bb09cb7326b889ddc81573a5b5105485 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Thu, 29 Aug 2024 16:10:23 -0700 Subject: [PATCH 45/87] Add functionality for context graph creation --- common/graph_db_client.py | 21 ++----------------- graphrag/index/cli.py | 3 ++- .../index/context_switch/contextSwitcher.py | 17 ++++++++++++--- graphrag/index/emit/factories.py | 7 ++++--- graphrag/index/emit/graph_db_emitter.py | 4 ++-- graphrag/index/run.py | 6 +++++- graphrag/query/cli.py | 2 +- poetry.lock | 17 ++++++++++++++- pyproject.toml | 1 + 9 files changed, 47 insertions(+), 31 deletions(-) diff --git a/common/graph_db_client.py b/common/graph_db_client.py index 483c49d2f6..4c83818585 100644 --- a/common/graph_db_client.py +++ b/common/graph_db_client.py @@ -13,13 +13,12 @@ import json class GraphDBClient: - def __init__(self,graph_db_params: GraphDBConfig|None): + def __init__(self,graph_db_params: GraphDBConfig|None,context_id: str|None): self.username_prefix=graph_db_params.username - self.current_context="00000000-0000-0000-0000-000000000000" self._client=client.Client( url=f"wss://{graph_db_params.account_name}.gremlin.cosmos.azure.com:443/", traversal_source="g", - username=self.username_prefix+"-contextid-"+self.current_context, + username=self.username_prefix+"-contextid-"+context_id, password=f"{graph_db_params.account_key}", message_serializer=serializer.GraphSONSerializersV2d0(), ) @@ -77,22 +76,7 @@ def element_exists(self,element_type:str,element_id:int,conditions:str="")->bool element_count=counts[0] return element_count>0 - def switch_graphs(self,context_id:str)->None: - if context_id==self.current_context: - return - self.current_context=context_id - updated_client=client.Client( - url=self._client._url, - traversal_source="g", - username=self.username_prefix+"-contextid-"+self.current_context, - password=self._client._password, - message_serializer=serializer.GraphSONSerializersV2d0(), - ) - self._client.close() - self._client=updated_client - def write_vertices(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None: - self.switch_graphs(context_id) for row in data.itertuples(): if self.element_exists("g.V()",row.id): continue @@ -126,7 +110,6 @@ def write_vertices(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-00 def write_edges(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None: - self.switch_graphs(context_id) for row in data.itertuples(): if self.element_exists("g.E()",row.id): continue diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index 67b9f5a804..b8e3c51a37 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -109,7 +109,7 @@ def index_cli( ValueError("ContextOperation is invalid: It should be Active or DeActive") #graphrag_config = _read_config_parameters(root, config, progress_reporter) 
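# A hedged usage sketch (not part of the patch): one GraphDBClient per
# context id, as wired above. SimpleNamespace stands in for GraphDBConfig,
# and the account/key/username values are placeholders.
from types import SimpleNamespace

from common.graph_db_client import GraphDBClient

graphdb_params = SimpleNamespace(
    account_name="my-cosmos-account",              # placeholder
    account_key="<account-key>",                   # placeholder
    username="/dbs/graphrag/colls/entity-graph",   # Gremlin username prefix
)
context_id = "00000000-0000-0000-0000-000000000000"
gdb = GraphDBClient(graphdb_params, context_id)    # graph "<prefix>-contextid-<id>"
vertices_df = gdb.query_vertices(context_id)       # all vertices as a DataFrame
edges_df = gdb.query_edges(context_id)             # all edges as a DataFrame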
_switch_context(config,root,context_operation,context_id,progress_reporter,community_level) - sys.exit(0) + #sys.exit(0) cache = NoopPipelineCache() if nocache else None pipeline_emit = emit.split(",") if emit else None encountered_errors = False @@ -145,6 +145,7 @@ async def execute(): else None ), is_resume_run=bool(resume), + context_id=context_id, ): if output.errors and len(output.errors) > 0: encountered_errors = True diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index aa3d06d73d..d4f982bbd3 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -28,6 +28,7 @@ read_indexer_text_units, ) from graphrag.model.entity import Entity +from azure.cosmos import CosmosClient, PartitionKey class ContextSwitcher: """ContextSwitcher class definition.""" @@ -83,7 +84,6 @@ def activate(self): #2. read the file from storage using common/blob_storage_client.py #3. GraphDB: use cosmos db client to load data into Cosmos DB. #4. KustoDB: use Kusto client to load embedding data into Kusto. - data_dir=self.data_dir root_dir=self.root_dir config_dir=self.config_dir @@ -199,7 +199,19 @@ def _read_config_parameters(root: str, config: str | None): final_entities = pd.DataFrame() final_covariates = pd.DataFrame() if config.graphdb.enabled: - graph_db_client = GraphDBClient(config.graphdb) + cosmos_client = CosmosClient( + f"https://{config.graphdb.account_name}.documents.azure.com:443/", + f"{config.graphdb.account_key}", + ) + database_name = config.graphdb.username.split("/")[2] + database = cosmos_client.get_database_client(database_name) + graph_name=config.graphdb.username.split("/")[-1]+"-contextid-"+context_id + graph = database.create_container_if_not_exists( + id=graph_name, + partition_key=PartitionKey(path='/category'), + offer_throughput=400 + ) + graph_db_client = GraphDBClient(config.graphdb,context_id) for data_path in data_paths: #check from the config for the ouptut storage type and then read the data from the storage. @@ -240,7 +252,6 @@ def _read_config_parameters(root: str, config: str | None): activate=1, config_args=vector_store_args, ) - def deactivate(self): """DeActivate the context.""" #1. Delete all the data for a given context id. 
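# A hedged sketch (not part of the patch) of the container-per-context
# pattern used in activate() above; the account, key, and database names
# are placeholders, and the client comes from the azure-cosmos dependency
# added later in this series.
from azure.cosmos import CosmosClient, PartitionKey

account_name = "my-cosmos-account"                       # placeholder
context_id = "00000000-0000-0000-0000-000000000000"
cosmos_client = CosmosClient(
    f"https://{account_name}.documents.azure.com:443/",  # SQL endpoint of the Gremlin account
    "<account-key>",                                     # placeholder
)
database = cosmos_client.get_database_client("graphrag") # assumed database name
graph = database.create_container_if_not_exists(
    id=f"graph-contextid-{context_id}",                  # one graph per context
    partition_key=PartitionKey(path="/category"),
    offer_throughput=400,
)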
diff --git a/graphrag/index/emit/factories.py b/graphrag/index/emit/factories.py index edabba4506..1c4e218785 100644 --- a/graphrag/index/emit/factories.py +++ b/graphrag/index/emit/factories.py @@ -15,7 +15,7 @@ from .types import TableEmitterType def create_table_emitter( - emitter_type: TableEmitterType, storage: PipelineStorage, on_error: ErrorHandlerFn, graphdb_params: GraphDBConfig|None = None + emitter_type: TableEmitterType, storage: PipelineStorage, on_error: ErrorHandlerFn, graphdb_params: GraphDBConfig|None = None, context_id: str|None = None ) -> TableEmitter: """Create a table emitter based on the specified type.""" match emitter_type: @@ -26,7 +26,7 @@ def create_table_emitter( case TableEmitterType.CSV: return CSVTableEmitter(storage) case TableEmitterType.Graphdb: - return GraphDBEmitter(graphdb_params) + return GraphDBEmitter(graphdb_params,context_id) case _: msg = f"Unsupported table emitter type: {emitter_type}" raise ValueError(msg) @@ -37,9 +37,10 @@ def create_table_emitters( storage: PipelineStorage, on_error: ErrorHandlerFn, graphdb_params: GraphDBConfig|None = None, + context_id: str|None = None, ) -> list[TableEmitter]: """Create a list of table emitters based on the specified types.""" return [ - create_table_emitter(emitter_type, storage, on_error, graphdb_params) + create_table_emitter(emitter_type, storage, on_error, graphdb_params,context_id) for emitter_type in emitter_types ] diff --git a/graphrag/index/emit/graph_db_emitter.py b/graphrag/index/emit/graph_db_emitter.py index 4bb4244fa4..d8018ee678 100644 --- a/graphrag/index/emit/graph_db_emitter.py +++ b/graphrag/index/emit/graph_db_emitter.py @@ -11,8 +11,8 @@ class GraphDBEmitter(TableEmitter): """Graph DB Emitter.""" - def __init__(self, graph_db_params: GraphDBConfig|None): - self.graph_db_client = GraphDBClient(graph_db_params) + def __init__(self, graph_db_params: GraphDBConfig|None,context_id: str|None): + self.graph_db_client = GraphDBClient(graph_db_params,context_id) self.allowed_workflows = ['create_final_entities','create_final_relationships'] async def emit(self, name: str, data: pd.DataFrame) -> None: diff --git a/graphrag/index/run.py b/graphrag/index/run.py index 7ca8bb9264..27416f7d7f 100644 --- a/graphrag/index/run.py +++ b/graphrag/index/run.py @@ -83,6 +83,7 @@ async def run_pipeline_with_config( memory_profile: bool = False, run_id: str | None = None, is_resume_run: bool = False, + context_id: str | None = None, **_kwargs: dict, ) -> AsyncIterable[PipelineRunResult]: """Run a pipeline with the given config. @@ -166,6 +167,7 @@ def _create_postprocess_steps( emit=emit, is_resume_run=is_resume_run, graphdb_params = config.graphdb_params, + context_id=context_id, ): yield table @@ -184,6 +186,7 @@ async def run_pipeline( memory_profile: bool = False, is_resume_run: bool = False, graphdb_params: GraphDBConfig|None = None, + context_id: str | None = None, **_kwargs: dict, ) -> AsyncIterable[PipelineRunResult]: """Run the pipeline. 
@@ -219,7 +222,8 @@ async def run_pipeline( lambda e, s, d: cast(WorkflowCallbacks, callbacks).on_error( "Error emitting table", e, s, d ), - graphdb_params + graphdb_params, + context_id, ) loaded_workflows = load_workflows( workflows, diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 3457029aa9..44d6d1bb55 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -185,7 +185,7 @@ def run_local_search( final_entities = pd.DataFrame() final_covariates = pd.DataFrame() if config.graphdb.enabled: - graph_db_client = GraphDBClient(config.graphdb) + graph_db_client = GraphDBClient(config.graphdb,context_id) for data_path in data_paths: #check from the config for the ouptut storage type and then read the data from the storage. diff --git a/poetry.lock b/poetry.lock index fbacd85265..ce5edf4733 100644 --- a/poetry.lock +++ b/poetry.lock @@ -412,6 +412,21 @@ typing-extensions = ">=4.6.0" [package.extras] aio = ["aiohttp (>=3.0)"] +[[package]] +name = "azure-cosmos" +version = "4.7.0" +description = "Microsoft Azure Cosmos Client Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure-cosmos-4.7.0.tar.gz", hash = "sha256:72d714033134656302a2e8957c4b93590673bd288b0ca60cb123e348ae99a241"}, + {file = "azure_cosmos-4.7.0-py3-none-any.whl", hash = "sha256:03d8c7740ddc2906fb16e07b136acc0fe6a6a02656db46c5dd6f1b127b58cc96"}, +] + +[package.dependencies] +azure-core = ">=1.25.1" +typing-extensions = ">=4.6.0" + [[package]] name = "azure-identity" version = "1.17.1" @@ -5858,4 +5873,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "f06dbe201e3dfea982b0b052a3d6811e1be7acde113f3276f61390bd80684447" +content-hash = "3486484392ad6f778df36fc15643b7a0ffa7495893ba15c66ad467f946dca692" diff --git a/pyproject.toml b/pyproject.toml index 6978359eb2..4ae07d7552 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ azure-storage-blob = "^12.19.0" azure-identity = "^1.17.1" json-repair = "^0.25.3" gremlinpython = "^3.7.2" +azure-cosmos = "^4.7.0" [tool.poetry.group.dev.dependencies] coverage = "^7.6.0" From 628a4c2fda1b0e5525ce8d53e0ae3e5b0cae45ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Fri, 30 Aug 2024 11:05:37 -0700 Subject: [PATCH 46/87] Address comments --- common/graph_db_client.py | 8 ++++---- graphrag/index/cli.py | 2 +- graphrag/index/context_switch/contextSwitcher.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/common/graph_db_client.py b/common/graph_db_client.py index 4c83818585..5962484a07 100644 --- a/common/graph_db_client.py +++ b/common/graph_db_client.py @@ -43,7 +43,7 @@ def result_to_df(self,result) -> pd.DataFrame: df = pd.DataFrame(json_data) return df - def query_vertices(self,context_id:str="00000000-0000-0000-0000-000000000000") -> pd.DataFrame: + def query_vertices(self,context_id:str) -> pd.DataFrame: result = self._client.submit( message=( "g.V()" @@ -51,7 +51,7 @@ def query_vertices(self,context_id:str="00000000-0000-0000-0000-000000000000") - ) return self.result_to_df(result) - def query_edges(self,context_id:str="00000000-0000-0000-0000-000000000000") -> pd.DataFrame: + def query_edges(self,context_id:str) -> pd.DataFrame: result = self._client.submit( message=( "g.E()" @@ -76,7 +76,7 @@ def element_exists(self,element_type:str,element_id:int,conditions:str="")->bool element_count=counts[0] return element_count>0 - def 
write_vertices(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None: + def write_vertices(self,data: pd.DataFrame)->None: for row in data.itertuples(): if self.element_exists("g.V()",row.id): continue @@ -109,7 +109,7 @@ def write_vertices(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-00 time.sleep(5) - def write_edges(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None: + def write_edges(self,data: pd.DataFrame)->None: for row in data.itertuples(): if self.element_exists("g.E()",row.id): continue diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index b8e3c51a37..1e33e61e4c 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -109,7 +109,7 @@ def index_cli( ValueError("ContextOperation is invalid: It should be Active or DeActive") #graphrag_config = _read_config_parameters(root, config, progress_reporter) _switch_context(config,root,context_operation,context_id,progress_reporter,community_level) - #sys.exit(0) + sys.exit(0) cache = NoopPipelineCache() if nocache else None pipeline_emit = emit.split(",") if emit else None encountered_errors = False diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index d4f982bbd3..ccf1c0f90b 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -224,8 +224,8 @@ def _read_config_parameters(root: str, config: str | None): final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) if config.graphdb.enabled: - final_relationships = pd.concat([final_relationships, graph_db_client.query_edges()]) - final_entities = pd.concat([final_entities, graph_db_client.query_vertices()]) + final_relationships = pd.concat([final_relationships, graph_db_client.query_edges(context_id)]) + final_entities = pd.concat([final_entities, graph_db_client.query_vertices(context_id)]) else: final_relationships = pd.concat([final_relationships, read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")]) final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")]) From fd76db5fcdc31c2e0ad8628460dfeae98fa89b3c Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Fri, 30 Aug 2024 10:45:36 -0700 Subject: [PATCH 47/87] Adding community reports to Kusto --- graphrag/query/cli.py | 16 ++++--- .../context_builder/entity_extraction.py | 3 ++ graphrag/query/indexer_adapters.py | 1 + .../local_search/mixed_context.py | 22 ++++++--- graphrag/vector_stores/base.py | 13 ++++- graphrag/vector_stores/kusto.py | 48 +++++++++++++++++++ 6 files changed, 88 insertions(+), 15 deletions(-) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 44d6d1bb55..1609b7a5f0 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -194,8 +194,8 @@ def run_local_search( final_community_reports = pd.concat([final_community_reports,read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) # KustoDB: Final_entities, Final_Nodes, Final_report should be merged and inserted to kusto final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) # lance db search need it for embedding mapping. we have embeddings in entities we should use from there. KustoDB already must have sorted it. 
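# A hedged sketch (not part of the patch) of the Kusto control-command flow
# behind the load/unload calls in kusto.py; the cluster URI, database, and
# table layout are placeholders, azure-kusto-data is the actual client used.
from azure.kusto.data import KustoClient, KustoConnectionStringBuilder

kcsb = KustoConnectionStringBuilder.with_az_cli_authentication(
    "https://mycluster.kusto.windows.net"                 # placeholder cluster
)
kusto = KustoClient(kcsb)
database = "graphrag"                                     # placeholder database
kusto.execute(database, ".drop table entities ifexists")  # reset, then recreate
kusto.execute(
    database,
    ".create table entities (id: string, description_embedding: dynamic)",
)
# inline ingestion mirrors the f-string .ingest commands in kusto.py
kusto.execute(database, '.ingest inline into table entities <| id1,"[0.1,0.2]"')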
- if not optimized_search: - final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) + # if not optimized_search: + # final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) if config.graphdb.enabled: final_relationships = pd.concat([final_relationships, graph_db_client.query_edges(context_id)]) @@ -213,8 +213,9 @@ def run_local_search( vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it. - - + reports=read_indexer_reports( + final_community_reports, final_nodes, community_level + ) description_embedding_store = __get_embedding_description_store( entities=entities, vector_store_type=vector_store_type, @@ -227,14 +228,15 @@ def run_local_search( else [] ) + if(isinstance(description_embedding_store, KustoVectorStore)): entities = [] + description_embedding_store.load_reports(reports) + reports = [] search_engine = get_local_search_engine( config, - reports=read_indexer_reports( - final_community_reports, final_nodes, community_level - ), + reports=reports, text_units=read_indexer_text_units(final_text_units), entities=entities, relationships=read_indexer_relationships(final_relationships), diff --git a/graphrag/query/context_builder/entity_extraction.py b/graphrag/query/context_builder/entity_extraction.py index fe8253b612..037da80932 100644 --- a/graphrag/query/context_builder/entity_extraction.py +++ b/graphrag/query/context_builder/entity_extraction.py @@ -46,6 +46,9 @@ def map_query_to_entities_in_place( text_embedder=lambda t: text_embedder.embed(t), k=k * oversample_scaler, ) + import ast + for result in search_results: + result.community_ids = ast.literal_eval(result.community_ids) return search_results def map_query_to_entities( diff --git a/graphrag/query/indexer_adapters.py b/graphrag/query/indexer_adapters.py index c3b5f6e1ab..101fc16f9c 100644 --- a/graphrag/query/indexer_adapters.py +++ b/graphrag/query/indexer_adapters.py @@ -94,6 +94,7 @@ def read_indexer_reports( report_df = _filter_under_community_level(report_df, community_level) report_df = report_df.merge(filtered_community_df, on="community", how="inner") + report_df = report_df.drop_duplicates(subset=["community"]) return read_community_reports( df=report_df, diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index b26d0581ee..0d36cce761 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -43,6 +43,7 @@ from graphrag.query.llm.text_utils import num_tokens from graphrag.query.structured_search.base import LocalContextBuilder from graphrag.vector_stores import BaseVectorStore +from graphrag.vector_stores.kusto import KustoVectorStore log = logging.getLogger(__name__) @@ -238,7 +239,7 @@ def _build_community_context( is_optimized_search: bool = False, ) -> tuple[str, dict[str, pd.DataFrame]]: """Add community data to the context window until it hits the max_tokens limit.""" - if len(selected_entities) == 0 or len(self.community_reports) == 0: + if len(selected_entities) == 0 or (len(self.community_reports) == 0 and isinstance(self.entity_text_embeddings, KustoVectorStore)): return ("", {context_name.lower(): 
pd.DataFrame()}) community_matches = {} @@ -250,12 +251,19 @@ def _build_community_context( community_matches.get(community_id, 0) + 1 ) + selected_communities = [] + if len(self.community_reports) == 0: + selected_communities = self.entity_text_embeddings.get_extracted_communities( + community_ids=list(community_matches.keys()) + ) + else: + selected_communities = [ + self.community_reports[community_id] + for community_id in community_matches + if community_id in self.community_reports + ] + # sort communities by number of matched entities and rank - selected_communities = [ - self.community_reports[community_id] - for community_id in community_matches - if community_id in self.community_reports - ] for community in selected_communities: if community.attributes is None: community.attributes = {} @@ -450,7 +458,7 @@ def _build_local_context( relationship_context, self.token_encoder ) - + # build covariate context for covariate in self.covariates: covariate_context, covariate_context_data = build_covariates_context( diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index 104c9790c3..390aea159f 100644 --- a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field from typing import Any +from graphrag.model.community_report import CommunityReport from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder @@ -91,4 +92,14 @@ def get_extracted_entities( @abstractmethod def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: - """Load entities into the vector-store.""" \ No newline at end of file + """Load entities into the vector-store.""" + + @abstractmethod + def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: + """Load reports into the vector-store.""" + + @abstractmethod + def get_extracted_communities( + self, community_ids: list[int], **kwargs: Any + ) -> list[CommunityReport]: + """Get reports for a given list of community ids.""" \ No newline at end of file diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index e12c872239..740e557f88 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -6,6 +6,7 @@ import typing from azure.kusto.data import KustoClient, KustoConnectionStringBuilder from azure.kusto.data.helpers import dataframe_from_result_table +from graphrag.model.community_report import CommunityReport from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder @@ -230,3 +231,50 @@ def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: # Ingest data ingestion_command = f".ingest inline into table {self.collection_name} <| {df.to_csv(index=False, header=False)}" self.client.execute(self.database, ingestion_command) + + def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: + # Convert data to DataFrame + df = pd.DataFrame(reports) + table_name = "reports" + + # Create or replace table + if overwrite: + command = f".drop table {table_name} ifexists" + self.client.execute(self.database, command) + command = f".create table {table_name} (id: string, short_id: string, title: string, community_id: string, summary: string, full_content: string, rank: real, summary_embedding: dynamic, full_content_embedding: dynamic, attributes: dynamic)" + self.client.execute(self.database, command) + command = f".alter column {table_name}.summary_embedding policy encoding type = 
'Vector16'" + self.client.execute(self.database, command) + command = f".alter column {table_name}.full_content_embedding policy encoding type = 'Vector16'" + self.client.execute(self.database, command) + + # Ingest data + ingestion_command = f".ingest inline into table {table_name} <| {df.to_csv(index=False, header=False)}" + self.client.execute(self.database, ingestion_command) + + + def get_extracted_communities( + self, community_ids: list[int], **kwargs: Any + ) -> list[CommunityReport]: + community_ids = ", ".join([str(id) for id in community_ids]) + query = f""" + reports + | where community_id in ({community_ids}) + """ + response = self.client.execute(self.database, query) + df = dataframe_from_result_table(response.primary_results[0]) + + return [ + CommunityReport( + id=row["id"], + short_id=row["short_id"], + title=row["title"], + community_id=row["community_id"], + summary=row["summary"], + full_content=row["full_content"], + rank=row["rank"], + summary_embedding=row["summary_embedding"], + full_content_embedding=row["full_content_embedding"], + attributes=row["attributes"], + ) for _, row in df.iterrows() + ] From aa6095bd65836f6f16138f0f2be4413858bc745d Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Fri, 30 Aug 2024 14:36:01 -0700 Subject: [PATCH 48/87] Adding configurations use_kusto_community_reports, updating - to _ configurations, adding reports into activation/load --- graphrag/index/__main__.py | 14 ++++-- graphrag/index/cli.py | 17 +++++-- .../index/context_switch/contextSwitcher.py | 50 +++++++++++-------- graphrag/query/__main__.py | 13 +++-- graphrag/query/cli.py | 17 ++++--- graphrag/query/factories.py | 4 +- .../local_search/mixed_context.py | 8 +-- graphrag/vector_stores/azure_ai_search.py | 11 +++- graphrag/vector_stores/base.py | 4 +- graphrag/vector_stores/kusto.py | 15 +++--- graphrag/vector_stores/lancedb.py | 11 +++- 11 files changed, 107 insertions(+), 57 deletions(-) diff --git a/graphrag/index/__main__.py b/graphrag/index/__main__.py index 9a652ae07a..a501f829df 100644 --- a/graphrag/index/__main__.py +++ b/graphrag/index/__main__.py @@ -53,13 +53,13 @@ type=str, ) parser.add_argument( - "--context-id", + "--context_id", required=False, help="Context id to activate or deactivate.", type=str ) parser.add_argument( - "--context-operation", + "--context_operation", help="Context operation activate or deactivate.", required=False, # Only required if contextId is provided @@ -77,7 +77,7 @@ action="store_true", ) parser.add_argument( - "--overlay-defaults", + "--overlay_defaults", help="Overlay default configuration values on a provided configuration file (--config).", action="store_true", ) @@ -87,6 +87,11 @@ type=int, default=2, ) + parser.add_argument( + "--use_kusto_community_reports", + help="If enabled community reports are loaded into Kusto during activation", + action="store_true", + ) args = parser.parse_args() @@ -108,5 +113,6 @@ cli=True, context_id=args.context_id, context_operation=args.context_operation, - community_level=args.community_level + community_level=args.community_level, + use_kusto_community_reports=args.use_kusto_community_reports ) diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index 1e33e61e4c..8fd3a80dd6 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -86,6 +86,7 @@ def index_cli( dryrun: bool, overlay_defaults: bool, cli: bool = False, + use_kusto_community_reports: bool = False, ): """Run the pipeline with the given config.""" run_id = resume or time.strftime("%Y%m%d-%H%M%S") @@ -107,8 +108,15 
@@ def index_cli( ValueError("ContextId is invalid: It should be a valid Guid") if (context_operation != ContextSwitchType.Activate and context_operation != ContextSwitchType.Deactivate): ValueError("ContextOperation is invalid: It should be Active or DeActive") - #graphrag_config = _read_config_parameters(root, config, progress_reporter) - _switch_context(config,root,context_operation,context_id,progress_reporter,community_level) + _switch_context( + config, + root, + context_operation, + context_id, + progress_reporter, + community_level, + use_kusto_community_reports, + ) sys.exit(0) cache = NoopPipelineCache() if nocache else None pipeline_emit = emit.split(",") if emit else None @@ -185,11 +193,12 @@ async def execute(): sys.exit(1 if encountered_errors else 0) def _switch_context(config: GraphRagConfig | str, root: str , context_operation: str | None, - context_id: str, reporter: ProgressReporter,community_level: int) -> None: + context_id: str, reporter: ProgressReporter,community_level: int, + use_kusto_community_reports: bool) -> None: """Switch the context to the given context.""" reporter.info(f"Switching context to {context_id} using operation {context_operation}") from graphrag.index.context_switch.contextSwitcher import ContextSwitcher - context_switcher = ContextSwitcher(root,config,reporter,context_id,community_level) + context_switcher = ContextSwitcher(root,config, reporter,context_id,community_level,use_kusto_community_reports) if context_operation == ContextSwitchType.Activate: context_switcher.activate() elif context_operation == ContextSwitchType.Deactivate: diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index ccf1c0f90b..4b902d42d2 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -1,31 +1,29 @@ -from graphrag.common.progress import ProgressReporter -from graphrag.config import GraphRagConfig -from graphrag.config.enums import StorageType -from graphrag.common.storage import PipelineStorage, BlobPipelineStorage, FilePipelineStorage -from graphrag.common.utils.context_utils import get_files_by_contextid -import pandas as pd -from typing import cast -from azure.core.exceptions import ResourceNotFoundError import asyncio +import os from io import BytesIO from pathlib import Path +from typing import cast + +import pandas as pd + +from common.graph_db_client import GraphDBClient +from graphrag.common.progress import ProgressReporter +from graphrag.common.storage import ( + BlobPipelineStorage, + FilePipelineStorage, + PipelineStorage, +) +from graphrag.common.utils.context_utils import get_files_by_contextid from graphrag.config import ( - create_graphrag_config, GraphRagConfig, + create_graphrag_config, ) -from common.graph_db_client import GraphDBClient -import os -from graphrag.vector_stores import VectorStoreFactory, VectorStoreType -from graphrag.vector_stores.base import BaseVectorStore -from graphrag.vector_stores.lancedb import LanceDBVectorStore -from graphrag.vector_stores.kusto import KustoVectorStore +from graphrag.config.enums import StorageType +from graphrag.model.community_report import CommunityReport +from graphrag.model.entity import Entity from graphrag.query.indexer_adapters import ( - read_indexer_covariates, read_indexer_entities, - read_indexer_relationships, read_indexer_reports, - kt_read_indexer_reports, - read_indexer_text_units, ) from graphrag.model.entity import Entity from azure.cosmos import CosmosClient, 
PartitionKey @@ -36,7 +34,8 @@ class ContextSwitcher: def __init__(self, root_dir:str , config_dir:str,reporter: ProgressReporter, context_id:str, community_level:int , data_dir: str = None, - optimized_search: bool= False): + optimized_search: bool= False, + use_kusto_community_reports: bool = False,): self.root_dir=root_dir self.config_dir=config_dir @@ -45,11 +44,13 @@ def __init__(self, root_dir:str , config_dir:str,reporter: ProgressReporter, self.context_id=context_id self.optimized_search=optimized_search self.community_level = community_level + self.use_kusto_community_reports = use_kusto_community_reports def set_ctx_activation( self, activate: int, entities: list[Entity]=[], + reports: list[CommunityReport]=[], config_args: dict | None = None, ): if not config_args: @@ -65,6 +66,7 @@ def set_ctx_activation( "vector_search_column", "description_embedding" ) config_args.update({"vector_name": vector_name}) + config_args.update({"reports_name": f"reports_{self.context_id}"}) description_embedding_store = VectorStoreFactory.get_vector_store( vector_store_type=VectorStoreType.Kusto, kwargs=config_args @@ -73,8 +75,11 @@ def set_ctx_activation( if activate: description_embedding_store.load_entities(entities) + if self.use_kusto_community_reports: + description_embedding_store.load_reports(reports) else: description_embedding_store.unload_entities() + # I don't think it is necessary to unload anything as the retention policy will take care of it. return 0 @@ -246,10 +251,13 @@ def _read_config_parameters(root: str, config: str | None): ValueError("Context switching is only supporeted for vectore_store.type=kusto ") entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it. 
+ reports = read_indexer_reports(final_community_reports, final_nodes, community_level) self.set_ctx_activation( entities=entities, - activate=1, config_args=vector_store_args, + reports=reports, + activate=1, + config_args=vector_store_args, ) def deactivate(self): diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py index 1d1cd25b4f..e0a5e9f583 100644 --- a/graphrag/query/__main__.py +++ b/graphrag/query/__main__.py @@ -16,8 +16,6 @@ class SearchType(Enum): LOCAL = "local" GLOBAL = "global" - KUSTO_LOCAL = "kusto_local" - KUSTO_GLOBAL = "kusto_global" def __str__(self): """Return the string representation of the enum value.""" @@ -85,6 +83,12 @@ def __str__(self): default=False, ) + parser.add_argument( + "--use_kusto_community_reports", + help="If enabled community reports are attempted to be used in Kusto during query", + action="store_true", + ) + parser.add_argument( "query", nargs=1, @@ -92,8 +96,6 @@ def __str__(self): type=str, ) - - args = parser.parse_args() match args.method: @@ -106,7 +108,8 @@ def __str__(self): args.response_type, args.context_id, args.query[0], - optimized_search=args.optimized_search + optimized_search=args.optimized_search, + use_kusto_community_reports=args.use_kusto_community_reports, ) case SearchType.GLOBAL: run_global_search( diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 1609b7a5f0..8fab982b55 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -45,9 +45,10 @@ reporter = PrintProgressReporter("") def __get_embedding_description_store( - entities: list[Entity], + entities: list[Entity] = [], vector_store_type: str = VectorStoreType.LanceDB, config_args: dict | None = None, + context_id: str = "", ): """Get the embedding description store.""" if not config_args: @@ -56,11 +57,12 @@ def __get_embedding_description_store( collection_name = config_args.get( "query_collection_name", "entity_description_embeddings" ) - config_args.update({"collection_name": collection_name}) + config_args.update({"collection_name": f"{collection_name}_{context_id}" if context_id else collection_name}) vector_name = config_args.get( "vector_search_column", "description_embedding" ) config_args.update({"vector_name": vector_name}) + config_args.update({"reports_name": f"reports_{context_id}" if context_id else "reports"}) description_embedding_store = VectorStoreFactory.get_vector_store( vector_store_type=vector_store_type, kwargs=config_args @@ -155,6 +157,7 @@ def run_local_search( context_id: str, query: str, optimized_search: bool = False, + use_kusto_community_reports: bool = False, ): """Run a local search with the given query.""" data_dir, root_dir, config = _configure_paths_and_settings( @@ -194,8 +197,8 @@ def run_local_search( final_community_reports = pd.concat([final_community_reports,read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) # KustoDB: Final_entities, Final_Nodes, Final_report should be merged and inserted to kusto final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) # lance db search need it for embedding mapping. we have embeddings in entities we should use from there. KustoDB already must have sorted it. 
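# A hedged sketch of the per-path artifact merge above; storage.get and the
# error handling are assumptions standing in for read_paraquet_file (sic),
# which the repo defines elsewhere.
from io import BytesIO

import pandas as pd

def read_parquet_or_empty(storage, path: str) -> pd.DataFrame:
    # missing artifacts come back empty so pd.concat stays well-defined
    try:
        return pd.read_parquet(BytesIO(storage.get(path)))
    except FileNotFoundError:
        return pd.DataFrame()

def merge_artifact(storage, data_paths: list[str], name: str) -> pd.DataFrame:
    # one frame per context data path, concatenated into a single artifact
    frames = [read_parquet_or_empty(storage, f"{p}/{name}") for p in data_paths]
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()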
- # if not optimized_search: - # final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) + if not optimized_search: + final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) if config.graphdb.enabled: final_relationships = pd.concat([final_relationships, graph_db_client.query_edges(context_id)]) @@ -220,6 +223,7 @@ def run_local_search( entities=entities, vector_store_type=vector_store_type, config_args=vector_store_args, + context_id=context_id, ) covariates = ( @@ -228,11 +232,11 @@ def run_local_search( else [] ) - if(isinstance(description_embedding_store, KustoVectorStore)): entities = [] description_embedding_store.load_reports(reports) - reports = [] + if use_kusto_community_reports: + reports = [] search_engine = get_local_search_engine( config, @@ -244,6 +248,7 @@ def run_local_search( description_embedding_store=description_embedding_store, response_type=response_type, is_optimized_search=optimized_search, + use_kusto_community_reports=use_kusto_community_reports, ) if optimized_search: diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py index 3dfe104230..e9775f2bf2 100644 --- a/graphrag/query/factories.py +++ b/graphrag/query/factories.py @@ -108,7 +108,8 @@ def get_local_search_engine( covariates: dict[str, list[Covariate]], response_type: str, description_embedding_store: BaseVectorStore, - is_optimized_search: bool = False + is_optimized_search: bool = False, + use_kusto_community_reports: bool = False, ) -> LocalSearch: """Create a local search engine based on data + configuration.""" llm = get_llm(config) @@ -130,6 +131,7 @@ def get_local_search_engine( text_embedder=text_embedder, token_encoder=token_encoder, is_optimized_search= is_optimized_search, + use_kusto_community_reports=use_kusto_community_reports, ), token_encoder=token_encoder, llm_params={ diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index 0d36cce761..988d290848 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -63,6 +63,7 @@ def __init__( token_encoder: tiktoken.Encoding | None = None, embedding_vectorstore_key: str = EntityVectorStoreKey.ID, is_optimized_search: bool = False, + use_kusto_community_reports: bool = False, ): if community_reports is None: community_reports = [] @@ -86,6 +87,7 @@ def __init__( self.token_encoder = token_encoder self.embedding_vectorstore_key = embedding_vectorstore_key self.is_optimized_search = is_optimized_search + self.use_kusto_community_reports = use_kusto_community_reports def filter_by_entity_keys(self, entity_keys: list[int] | list[str]): """Filter entity text embeddings by entity keys.""" @@ -239,7 +241,7 @@ def _build_community_context( is_optimized_search: bool = False, ) -> tuple[str, dict[str, pd.DataFrame]]: """Add community data to the context window until it hits the max_tokens limit.""" - if len(selected_entities) == 0 or (len(self.community_reports) == 0 and isinstance(self.entity_text_embeddings, KustoVectorStore)): + if len(selected_entities) == 0 or (len(self.community_reports) == 0 and not self.use_kusto_community_reports): return ("", {context_name.lower(): pd.DataFrame()}) community_matches = {} @@ -252,8 +254,8 @@ def _build_community_context( ) selected_communities = [] - if 
len(self.community_reports) == 0: - selected_communities = self.entity_text_embeddings.get_extracted_communities( + if self.use_kusto_community_reports: + selected_communities = self.entity_text_embeddings.get_extracted_reports( community_ids=list(community_matches.keys()) ) else: diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index efab79d70f..afb0c18393 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -24,6 +24,7 @@ ) from azure.search.documents.models import VectorizedQuery +from graphrag.model.community_report import CommunityReport from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder @@ -194,9 +195,15 @@ def similarity_search_by_text( ) return [] + def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: + raise NotImplementedError("Loading entities is not supported for Azure AI Search") + def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any ) -> list[Entity]: raise NotImplementedError("Extracting entities is not supported for Azure AI Search") - def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: - raise NotImplementedError("Loading entities is not supported for Azure AI Search") \ No newline at end of file + def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: + raise NotImplementedError("Loading reports is not supported for Azure AI Search") + + def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]: + raise NotImplementedError("Extracting reports is not supported for Azure AI Search") \ No newline at end of file diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index 390aea159f..67121a04b9 100644 --- a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -46,6 +46,7 @@ def __init__( self, collection_name: str, vector_name: str, + reports_name: str, db_connection: Any | None = None, document_collection: Any | None = None, query_filter: Any | None = None, @@ -53,6 +54,7 @@ def __init__( ): self.collection_name = collection_name self.vector_name = vector_name + self.reports_name = reports_name self.db_connection = db_connection self.document_collection = document_collection self.query_filter = query_filter @@ -99,7 +101,7 @@ def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) - """Load reports into the vector-store.""" @abstractmethod - def get_extracted_communities( + def get_extracted_reports( self, community_ids: list[int], **kwargs: Any ) -> list[CommunityReport]: """Get reports for a given list of community ids.""" \ No newline at end of file diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 740e557f88..ec969e1f5f 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -209,7 +209,7 @@ def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int attributes=row["attributes"], ) for _, row in df.iterrows() ] - + def unload_entities(self) -> None: self.client.execute(self.database,f".drop table {self.collection_name} ifexists") @@ -235,25 +235,24 @@ def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: # Convert data to DataFrame df = pd.DataFrame(reports) - table_name = "reports" # Create or replace 
table
         if overwrite:
-            command = f".drop table {table_name} ifexists"
+            command = f".drop table {self.reports_name} ifexists"
             self.client.execute(self.database, command)
-            command = f".create table {table_name} (id: string, short_id: string, title: string, community_id: string, summary: string, full_content: string, rank: real, summary_embedding: dynamic, full_content_embedding: dynamic, attributes: dynamic)"
+            command = f".create table {self.reports_name} (id: string, short_id: string, title: string, community_id: string, summary: string, full_content: string, rank: real, summary_embedding: dynamic, full_content_embedding: dynamic, attributes: dynamic)"
             self.client.execute(self.database, command)
-            command = f".alter column {table_name}.summary_embedding policy encoding type = 'Vector16'"
+            command = f".alter column {self.reports_name}.summary_embedding policy encoding type = 'Vector16'"
             self.client.execute(self.database, command)
-            command = f".alter column {table_name}.full_content_embedding policy encoding type = 'Vector16'"
+            command = f".alter column {self.reports_name}.full_content_embedding policy encoding type = 'Vector16'"
             self.client.execute(self.database, command)
 
         # Ingest data
-        ingestion_command = f".ingest inline into table {table_name} <| {df.to_csv(index=False, header=False)}"
+        ingestion_command = f".ingest inline into table {self.reports_name} <| {df.to_csv(index=False, header=False)}"
         self.client.execute(self.database, ingestion_command)
 
-    def get_extracted_communities(
+    def get_extracted_reports(
         self, community_ids: list[int], **kwargs: Any
     ) -> list[CommunityReport]:
         community_ids = ", ".join([str(id) for id in community_ids])
diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py
index 095b6bae42..a32b8abb51 100644
--- a/graphrag/vector_stores/lancedb.py
+++ b/graphrag/vector_stores/lancedb.py
@@ -121,9 +122,15 @@ def similarity_search_by_text(
             return self.similarity_search_by_vector(query_embedding, k)
         return []
 
+    def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None:
+        raise NotImplementedError("Loading entities is not supported for LanceDB")
+
     def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
     ) -> list[Entity]:
         raise NotImplementedError("Extracting entities is not supported for LanceDB")
 
-    def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None:
-        raise NotImplementedError("Loading entities is not supported for LanceDB")
+    def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None:
+        raise NotImplementedError("Loading reports is not supported for LanceDB")
+
+    def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]:
+        raise NotImplementedError("Extracting community reports is not supported for LanceDB")

From 3efb530f461b3340a79d72086e1b48b3e31f4f82 Mon Sep 17 00:00:00 2001
From: Liam Dannaher
Date: Tue, 3 Sep 2024 08:28:34 -0700
Subject: [PATCH 49/87] Separate out setup so that context switcher can call
 setup & load separately without keeping track of overwrite

---
 graphrag/vector_stores/azure_ai_search.py |  8 +++++-
graphrag/vector_stores/base.py | 10 ++++++- graphrag/vector_stores/kusto.py | 34 ++++++++++++----------- graphrag/vector_stores/lancedb.py | 6 ++++ 4 files changed, 40 insertions(+), 18 deletions(-) diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index afb0c18393..7e8de25f38 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -206,4 +206,10 @@ def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) - raise NotImplementedError("Loading reports is not supported for Azure AI Search") def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]: - raise NotImplementedError("Extracting reports is not supported for Azure AI Search") \ No newline at end of file + raise NotImplementedError("Extracting reports is not supported for Azure AI Search") + + def setup_entities(self) -> None: + raise NotImplementedError("Setting up entities is not supported for Azure AI Search") + + def setup_reports(self) -> None: + raise NotImplementedError("Setting up reports is not supported for Azure AI Search") \ No newline at end of file diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index 67121a04b9..f198b5a34b 100644 --- a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -104,4 +104,12 @@ def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) - def get_extracted_reports( self, community_ids: list[int], **kwargs: Any ) -> list[CommunityReport]: - """Get reports for a given list of community ids.""" \ No newline at end of file + """Get reports for a given list of community ids.""" + + @abstractmethod + def setup_entities(self) -> None: + """Setup the entities in the vector-store.""" + + @abstractmethod + def setup_reports(self) -> None: + """Setup the reports in the vector-store.""" \ No newline at end of file diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index ec969e1f5f..4277332406 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -213,39 +213,41 @@ def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int def unload_entities(self) -> None: self.client.execute(self.database,f".drop table {self.collection_name} ifexists") + def setup_entities(self) -> None: + command = f".create table {self.collection_name} (id: string, short_id: real, title: string, type: string, description: string, description_embedding: dynamic, name_embedding: dynamic, graph_embedding: dynamic, community_ids: dynamic, text_unit_ids: dynamic, document_ids: dynamic, rank: real, attributes: dynamic)" + self.client.execute(self.database, command) + command = f".alter column {self.collection_name}.graph_embedding policy encoding type = 'Vector16'" + self.client.execute(self.database, command) + command = f".alter column {self.collection_name}.description_embedding policy encoding type = 'Vector16'" + self.client.execute(self.database, command) + def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: # Convert data to DataFrame df = pd.DataFrame(entities) # Create or replace table if overwrite: - command = f".drop table {self.collection_name} ifexists" - self.client.execute(self.database, command) - command = f".create table {self.collection_name} (id: string, short_id: real, title: string, type: string, description: string, description_embedding: dynamic, name_embedding: dynamic, graph_embedding: 
dynamic, community_ids: dynamic, text_unit_ids: dynamic, document_ids: dynamic, rank: real, attributes: dynamic)" - self.client.execute(self.database, command) - command = f".alter column {self.collection_name}.graph_embedding policy encoding type = 'Vector16'" - self.client.execute(self.database, command) - command = f".alter column {self.collection_name}.description_embedding policy encoding type = 'Vector16'" - self.client.execute(self.database, command) + self.setup_entities() # Ingest data ingestion_command = f".ingest inline into table {self.collection_name} <| {df.to_csv(index=False, header=False)}" self.client.execute(self.database, ingestion_command) + def setup_reports(self) -> None: + command = f".create table {self.reports_name} (id: string, short_id: string, title: string, community_id: string, summary: string, full_content: string, rank: real, summary_embedding: dynamic, full_content_embedding: dynamic, attributes: dynamic)" + self.client.execute(self.database, command) + command = f".alter column {self.reports_name}.summary_embedding policy encoding type = 'Vector16'" + self.client.execute(self.database, command) + command = f".alter column {self.reports_name}.full_content_embedding policy encoding type = 'Vector16'" + self.client.execute(self.database, command) + def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: # Convert data to DataFrame df = pd.DataFrame(reports) # Create or replace table if overwrite: - command = f".drop table {self.reports_name} ifexists" - self.client.execute(self.database, command) - command = f".create table {self.reports_name} (id: string, short_id: string, title: string, community_id: string, summary: string, full_content: string, rank: real, summary_embedding: dynamic, full_content_embedding: dynamic, attributes: dynamic)" - self.client.execute(self.database, command) - command = f".alter column {self.reports_name}.summary_embedding policy encoding type = 'Vector16'" - self.client.execute(self.database, command) - command = f".alter column {self.reports_name}.full_content_embedding policy encoding type = 'Vector16'" - self.client.execute(self.database, command) + self.setup_reports() # Ingest data ingestion_command = f".ingest inline into table {self.reports_name} <| {df.to_csv(index=False, header=False)}" diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index a32b8abb51..f9e0e40528 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ -134,3 +134,9 @@ def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) - def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]: raise NotImplementedError("Extracting community reports is not supported for LanceDB") + + def setup_entities(self) -> None: + raise NotImplementedError("Setting up entities is not supported for LanceDB") + + def setup_reports(self) -> None: + raise NotImplementedError("Setting up community reports is not supported for LanceDB") \ No newline at end of file From c80510ac7bade53926711c8c756f4e0a972e5f82 Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Tue, 3 Sep 2024 11:01:02 -0700 Subject: [PATCH 50/87] Fixed bug where defaults for vector store weren't being set for indexing. 
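The fix threads explicit collection, vector, and reports names through _create_vector_store. A minimal sketch of the defaulting behaviour it restores, assuming the dict-shaped vector store config used throughout this series:

def apply_vector_store_defaults(config: dict) -> dict:
    # mirrors the defaults wired into text_embed.py below; unknown keys pass through
    out = dict(config)
    out.setdefault("collection_name", "entity_description_embeddings")
    out.setdefault("vector_name", "vector")
    out.setdefault("reports_name", "reports")
    return out

assert apply_vector_store_defaults({"type": "kusto"})["vector_name"] == "vector"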
--- graphrag/index/create_pipeline_config.py | 1 + graphrag/index/verbs/text/embed/text_embed.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index cf69e0bbdf..7cf91ec308 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -353,6 +353,7 @@ def _graph_workflows( "title_column": "description", "collection_name": "entity_description_embeddings", "vector_name": "vector", + "reports_name": "reports", }, ), "skip_name_embedding": skip_entity_name_embedding, diff --git a/graphrag/index/verbs/text/embed/text_embed.py b/graphrag/index/verbs/text/embed/text_embed.py index 30f572d05b..7148a6ed1b 100644 --- a/graphrag/index/verbs/text/embed/text_embed.py +++ b/graphrag/index/verbs/text/embed/text_embed.py @@ -83,9 +83,11 @@ async def text_embed( if vector_store_config: embedding_name = kwargs.get("embedding_name", "default") + vector_name = kwargs.get("vector_name", "vector") collection_name = _get_collection_name(vector_store_config, embedding_name) + vector_name = _get_collection_name(vector_store_config, vector_name) vector_store: BaseVectorStore = _create_vector_store( - vector_store_config, collection_name + vector_store_config, collection_name, vector_name, "reports" ) vector_store_workflow_config = vector_store_config.get( embedding_name, vector_store_config @@ -222,11 +224,15 @@ async def _text_embed_with_vector_store( def _create_vector_store( - vector_store_config: dict, collection_name: str + vector_store_config: dict, collection_name: str, vector_name: str, reports_name: str, ) -> BaseVectorStore: vector_store_type: str = str(vector_store_config.get("type")) if collection_name: vector_store_config.update({"collection_name": collection_name}) + if vector_name: + vector_store_config.update({"vector_name": vector_name}) + if reports_name: + vector_store_config.update({"reports_name": reports_name}) vector_store = VectorStoreFactory.get_vector_store( vector_store_type, kwargs=vector_store_config From 67e1033fbe62ef869ac1d228fc3288e5cc9aecbd Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Tue, 3 Sep 2024 11:01:58 -0700 Subject: [PATCH 51/87] Setup for vector store happens once per activation instead of for every set of parquet files. 
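The intent, in a runnable toy form (the class below is a stand-in, not the Kusto store): schema setup runs once per activation, and each parquet batch is then appended without re-creating tables:

class VectorStoreSketch:
    """Toy stand-in showing setup-once / load-per-batch, not the Kusto store."""

    def __init__(self) -> None:
        self.ready = False
        self.rows: list[str] = []

    def setup_entities(self) -> None:
        # drop-and-create happens exactly once per activation
        self.ready = True
        self.rows.clear()

    def load_entities(self, batch: list[str]) -> None:
        # ingestion appends; it must not re-create the table per parquet set
        assert self.ready, "setup_entities() must run before loading"
        self.rows.extend(batch)

store = VectorStoreSketch()
store.setup_entities()
for batch in (["e1", "e2"], ["e3"]):   # one batch per parquet directory
    store.load_entities(batch)
assert store.rows == ["e1", "e2", "e3"]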
--- .../index/context_switch/contextSwitcher.py | 35 +++++++------------ 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index 4b902d42d2..98f2532c40 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -27,6 +27,8 @@ ) from graphrag.model.entity import Entity from azure.cosmos import CosmosClient, PartitionKey +from graphrag.vector_stores.base import BaseVectorStore +from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType class ContextSwitcher: """ContextSwitcher class definition.""" @@ -46,13 +48,9 @@ def __init__(self, root_dir:str , config_dir:str,reporter: ProgressReporter, self.community_level = community_level self.use_kusto_community_reports = use_kusto_community_reports - def set_ctx_activation( - self, - activate: int, - entities: list[Entity]=[], - reports: list[CommunityReport]=[], - config_args: dict | None = None, - ): + def setup_vector_store(self, + config_args: dict | None = None,) -> BaseVectorStore: + """Set up the vector store and return it.""" if not config_args: config_args = {} @@ -73,15 +71,8 @@ def set_ctx_activation( ) description_embedding_store.connect(**config_args) - if activate: - description_embedding_store.load_entities(entities) - if self.use_kusto_community_reports: - description_embedding_store.load_reports(reports) - else: - description_embedding_store.unload_entities() - # I don't think it is necessary to unload anything as the retention policy will take care of it. - - return 0 + description_embedding_store.setup_entities() + return description_embedding_store def activate(self): """Activate the context.""" @@ -217,6 +208,9 @@ def _read_config_parameters(root: str, config: str | None): offer_throughput=400 ) graph_db_client = GraphDBClient(config.graphdb,context_id) + + description_embedding_store = self.setup_vector_store(config_args=config.embeddings.vector_store) + for data_path in data_paths: #check from the config for the ouptut storage type and then read the data from the storage. @@ -253,12 +247,9 @@ def _read_config_parameters(root: str, config: str | None): entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it. reports = read_indexer_reports(final_community_reports, final_nodes, community_level) - self.set_ctx_activation( - entities=entities, - reports=reports, - activate=1, - config_args=vector_store_args, - ) + description_embedding_store.load_entities(entities) + if self.use_kusto_community_reports: + description_embedding_store.load_reports(reports) def deactivate(self): """DeActivate the context.""" From c353d686b1c040748934d5298f2e90e7073470bb Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Tue, 3 Sep 2024 14:31:11 -0700 Subject: [PATCH 52/87] Arg mismatch & not calling load kusto in query anymore. 
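One caveat worth noting about the --optimized_search flag added below: argparse's type=bool does not parse "False", since any non-empty string is truthy, so action="store_true" (used for the other flags in this series) is the safer pattern. A runnable illustration:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--optimized_search", type=bool, default=False)   # gotcha
p.add_argument("--use_kusto_community_reports", action="store_true")

args = p.parse_args(["--optimized_search", "False"])
assert args.optimized_search is True      # bool("False") -> True
args = p.parse_args(["--use_kusto_community_reports"])
assert args.use_kusto_community_reports is True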
--- graphrag/index/__main__.py | 9 ++++++++- graphrag/index/cli.py | 31 +++++++++++++++++++++---------- graphrag/query/cli.py | 1 - 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/graphrag/index/__main__.py b/graphrag/index/__main__.py index a501f829df..de2c156a69 100644 --- a/graphrag/index/__main__.py +++ b/graphrag/index/__main__.py @@ -92,6 +92,12 @@ help="If enabled community reports are loaded into Kusto during activation", action="store_true", ) + parser.add_argument( + "--optimized_search", + help="Runs optimized search and export artifacts", + type=bool, + default=False, + ) args = parser.parse_args() @@ -114,5 +120,6 @@ context_id=args.context_id, context_operation=args.context_operation, community_level=args.community_level, - use_kusto_community_reports=args.use_kusto_community_reports + use_kusto_community_reports=args.use_kusto_community_reports, + optimized_search=args.optimized_search ) diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index 8fd3a80dd6..63fa9c190b 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -87,6 +87,7 @@ def index_cli( overlay_defaults: bool, cli: bool = False, use_kusto_community_reports: bool = False, + optimized_search: bool = False, ): """Run the pipeline with the given config.""" run_id = resume or time.strftime("%Y%m%d-%H%M%S") @@ -109,13 +110,14 @@ def index_cli( if (context_operation != ContextSwitchType.Activate and context_operation != ContextSwitchType.Deactivate): ValueError("ContextOperation is invalid: It should be Active or DeActive") _switch_context( - config, - root, - context_operation, - context_id, - progress_reporter, - community_level, - use_kusto_community_reports, + root=root, + config=config, + reporter=progress_reporter, + context_operation=context_operation, + context_id=context_id, + community_level=community_level, + optimized_search=optimized_search, + use_kusto_community_reports=use_kusto_community_reports, ) sys.exit(0) cache = NoopPipelineCache() if nocache else None @@ -192,13 +194,22 @@ async def execute(): if cli: sys.exit(1 if encountered_errors else 0) -def _switch_context(config: GraphRagConfig | str, root: str , context_operation: str | None, - context_id: str, reporter: ProgressReporter,community_level: int, +def _switch_context(root: str, config: str, + reporter: ProgressReporter, context_operation: str | None, + context_id: str, community_level: int, optimized_search: bool, use_kusto_community_reports: bool) -> None: """Switch the context to the given context.""" reporter.info(f"Switching context to {context_id} using operation {context_operation}") from graphrag.index.context_switch.contextSwitcher import ContextSwitcher - context_switcher = ContextSwitcher(root,config, reporter,context_id,community_level,use_kusto_community_reports) + context_switcher = ContextSwitcher( + root_dir=root, + config_dir=config, + reporter=reporter, + context_id=context_id, + community_level=community_level, + data_dir=None, + optimized_search=optimized_search, + use_kusto_community_reports=use_kusto_community_reports) if context_operation == ContextSwitchType.Activate: context_switcher.activate() elif context_operation == ContextSwitchType.Deactivate: diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 8fab982b55..c60c99524b 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -234,7 +234,6 @@ def run_local_search( if(isinstance(description_embedding_store, KustoVectorStore)): entities = [] - description_embedding_store.load_reports(reports) if 
use_kusto_community_reports:
            reports = []

From dae1bee4ff1ff7acc4ac64fee556389f4c088c07 Mon Sep 17 00:00:00 2001
From: Liam Dannaher
Date: Wed, 4 Sep 2024 07:18:36 -0700
Subject: [PATCH 53/87] Adding report_name var to query, load_ doesn't overwrite automatically, setup_ drops table to reset state.

---
 graphrag/vector_stores/kusto.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py
index 4277332406..71113c5d04 100644
--- a/graphrag/vector_stores/kusto.py
+++ b/graphrag/vector_stores/kusto.py
@@ -214,6 +214,8 @@ def unload_entities(self) -> None:
         self.client.execute(self.database,f".drop table {self.collection_name} ifexists")

     def setup_entities(self) -> None:
+        command = f".drop table {self.collection_name} ifexists"
+        self.client.execute(self.database, command)
         command = f".create table {self.collection_name} (id: string, short_id: real, title: string, type: string, description: string, description_embedding: dynamic, name_embedding: dynamic, graph_embedding: dynamic, community_ids: dynamic, text_unit_ids: dynamic, document_ids: dynamic, rank: real, attributes: dynamic)"
         self.client.execute(self.database, command)
         command = f".alter column {self.collection_name}.graph_embedding policy encoding type = 'Vector16'"
         self.client.execute(self.database, command)
         command = f".alter column {self.collection_name}.description_embedding policy encoding type = 'Vector16'"
         self.client.execute(self.database, command)

-    def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None:
+    def load_entities(self, entities: list[Entity], overwrite: bool = False) -> None:
         # Convert data to DataFrame
         df = pd.DataFrame(entities)
@@ -234,6 +236,8 @@ def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None:
         self.client.execute(self.database, ingestion_command)

     def setup_reports(self) -> None:
+        command = f".drop table {self.reports_name} ifexists"
+        self.client.execute(self.database, command)
         command = f".create table {self.reports_name} (id: string, short_id: string, title: string, community_id: string, summary: string, full_content: string, rank: real, summary_embedding: dynamic, full_content_embedding: dynamic, attributes: dynamic)"
         self.client.execute(self.database, command)
         command = f".alter column {self.reports_name}.summary_embedding policy encoding type = 'Vector16'"
         self.client.execute(self.database, command)
         command = f".alter column {self.reports_name}.full_content_embedding policy encoding type = 'Vector16'"
         self.client.execute(self.database, command)

-    def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None:
+    def load_reports(self, reports: list[CommunityReport], overwrite: bool = False) -> None:
         # Convert data to DataFrame
         df = pd.DataFrame(reports)
@@ -259,7 +263,7 @@ def get_extracted_reports(
     ) -> list[CommunityReport]:
         community_ids = ", ".join([str(id) for id in community_ids])
         query = f"""
-        reports
+        {self.reports_name}
         | where community_id in ({community_ids})
         """
         response = self.client.execute(self.database, query)

From eb6cbe4932219aff6824e274f5eb3361811cb15c Mon Sep 17 00:00:00 2001
From: Liam Dannaher
Date: Wed, 4 Sep 2024 07:19:33 -0700
Subject: [PATCH 54/87] Get rid of concat in context_switcher so each file gets uploaded separately

---
 .../index/context_switch/contextSwitcher.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git
a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index 98f2532c40..0ebf5845ae 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -72,6 +72,9 @@ def setup_vector_store(self, description_embedding_store.connect(**config_args) description_embedding_store.setup_entities() + if self.use_kusto_community_reports: + description_embedding_store.setup_reports() + return description_embedding_store def activate(self): @@ -215,19 +218,19 @@ def _read_config_parameters(root: str, config: str | None): #check from the config for the ouptut storage type and then read the data from the storage. #GraphDB: we may need to make change below to read nodes data from Graph DB - final_nodes = pd.concat([final_nodes, read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")]) - final_community_reports = pd.concat([final_community_reports,read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) # KustoDB: Final_entities, Final_Nodes, Final_report should be merged and inserted to kusto - final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) # lance db search need it for embedding mapping. we have embeddings in entities we should use from there. KustoDB already must have sorted it. + final_nodes = read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet") + final_community_reports = read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet") # KustoDB: Final_entities, Final_Nodes, Final_report should be merged and inserted to kusto + final_text_units = read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet") # lance db search need it for embedding mapping. we have embeddings in entities we should use from there. KustoDB already must have sorted it. 
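# A sketch of the pattern this hunk moves to (annotation, not patch content):
# with the pd.concat accumulation gone, each data_path is read and uploaded on
# its own iteration, so memory use is bounded by a single batch:
#     for data_path in data_paths:
#         final_nodes = read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")
#         # ...load this batch into the store, then continue with the next path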
if not optimized_search: - final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) + final_covariates = read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet") if config.graphdb.enabled: - final_relationships = pd.concat([final_relationships, graph_db_client.query_edges(context_id)]) - final_entities = pd.concat([final_entities, graph_db_client.query_vertices(context_id)]) + final_relationships = graph_db_client.query_edges(context_id) + final_entities = graph_db_client.query_vertices(context_id) else: - final_relationships = pd.concat([final_relationships, read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")]) - final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")]) + final_relationships = read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet") + final_entities = read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet") vector_store_args = ( From 560bbe1f2c479dba4c5abba19c1a392712d9f2cc Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Wed, 4 Sep 2024 11:08:14 -0700 Subject: [PATCH 55/87] Configuring in_memory embedding storage even with vector_store configuration --- graphrag/index/verbs/text/embed/text_embed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphrag/index/verbs/text/embed/text_embed.py b/graphrag/index/verbs/text/embed/text_embed.py index 7148a6ed1b..cd9bbc798d 100644 --- a/graphrag/index/verbs/text/embed/text_embed.py +++ b/graphrag/index/verbs/text/embed/text_embed.py @@ -81,7 +81,7 @@ async def text_embed( """ vector_store_config = strategy.get("vector_store") - if vector_store_config: + if vector_store_config and not vector_store_config.get("index_in_memory"): embedding_name = kwargs.get("embedding_name", "default") vector_name = kwargs.get("vector_name", "vector") collection_name = _get_collection_name(vector_store_config, embedding_name) From fa4251087e501c76c82e3121d1c80f5c3380dfd8 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Wed, 4 Sep 2024 13:24:21 -0700 Subject: [PATCH 56/87] Change entity ID generation --- .../verbs/graph/clustering/cluster_graph.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/graphrag/index/verbs/graph/clustering/cluster_graph.py b/graphrag/index/verbs/graph/clustering/cluster_graph.py index e8be50e8ff..0cfb929c63 100644 --- a/graphrag/index/verbs/graph/clustering/cluster_graph.py +++ b/graphrag/index/verbs/graph/clustering/cluster_graph.py @@ -15,6 +15,7 @@ from graphrag.index.utils import gen_uuid, load_graph from .typing import Communities +from hashlib import sha256 log = logging.getLogger(__name__) @@ -108,13 +109,17 @@ def cluster_graph( return TableContainer(table=output_df) +def generate_entity_id(candidate: str) -> str: + h=sha256() + h.update(candidate.encode()) + return h.hexdigest() # TODO: This should support str | nx.Graph as a graphml param def apply_clustering( - graphml: str, communities: Communities, level=0, seed=0xF001 + graphml: str, communities: Communities, level=0 ) -> nx.Graph: """Apply clustering to a graphml string.""" - random = Random(seed) # noqa S311 + graph = nx.parse_graphml(graphml) for community_level, community_id, nodes in communities: if level == community_level: @@ -126,16 +131,17 @@ def apply_clustering( for node_degree in graph.degree: 
        graph.nodes[str(node_degree[0])]["degree"] = int(node_degree[1])

-    # add node uuid and incremental record id (a human readable id used as reference in the final report)
+    # Generate a unique ID for each entity and incremental record id (a human readable id used as reference in the final report)
     for index, node in enumerate(graph.nodes()):
         graph.nodes[node]["human_readable_id"] = index
-        graph.nodes[node]["id"] = str(gen_uuid(random))
+        graph.nodes[node]["id"] = generate_entity_id(node)

     # add ids to edges
     for index, edge in enumerate(graph.edges()):
-        graph.edges[edge]["id"] = str(gen_uuid(random))
         graph.edges[edge]["human_readable_id"] = index
         graph.edges[edge]["level"] = level
+        graph.edges[edge]["id"] = generate_entity_id(f"{edge[0]}:{edge[1]}")
+
     return graph

From 71263a4eeb23b58e2ef13209db797136e2ba9bb1 Mon Sep 17 00:00:00 2001
From: gbarroutlook
Date: Wed, 4 Sep 2024 13:29:43 -0700
Subject: [PATCH 57/87] Add graphdb calls directly where relationships are filtered (#31)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add graphdb calls directly where relationships are filtered

* Add function to perform graphdb queries for relationships

---------

Co-authored-by: Guillermo Salvador Barrón Sánchez
Co-authored-by: logomachic
---
 common/graph_db_client.py | 6 +-
 .../index/context_switch/contextSwitcher.py | 13 ++-
 graphrag/query/cli.py | 8 +-
 .../query/context_builder/local_context.py | 8 +-
 graphrag/query/factories.py | 5 ++
 .../query/input/retrieval/relationships.py | 79 ++++++++++++++-----
 .../local_search/mixed_context.py | 12 ++-
 7 files changed, 96 insertions(+), 35 deletions(-)

diff --git a/common/graph_db_client.py b/common/graph_db_client.py
index 9a5318c97c..dc3b9d0e3a 100644
--- a/common/graph_db_client.py
+++ b/common/graph_db_client.py
@@ -103,13 +103,13 @@ def write_vertices(self,data: pd.DataFrame)->None:
                 "prop_partition_key": "entities",
                 "prop_description_embedding":json.dumps(row.description_embedding.tolist() if row.description_embedding is not None else []),
                 "prop_graph_embedding":json.dumps(row.graph_embedding.tolist() if row.graph_embedding is not None else []),
-                "prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []),
+                "prop_text_unit_ids":json.dumps(row.text_unit_ids.tolist() if row.text_unit_ids is not None else []),
             },
         )
         time.sleep(5)

-    def write_edges(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-000000000000")->None:
+    def write_edges(self,data: pd.DataFrame)->None:
         for row in data.itertuples():
             if self.element_exists("g.E()",row.id):
                 continue
@@ -134,7 +134,7 @@ def write_edges(self,data: pd.DataFrame,context_id:str="00000000-0000-0000-0000-
                 "prop_source_id": row.source,
                 "prop_target_id": row.target,
                 "prop_weight": row.weight,
-                "prop_text_unit_ids":json.dumps(row.text_unit_ids if row.text_unit_ids is not None else []),
+                "prop_text_unit_ids":json.dumps(row.text_unit_ids.tolist() if row.text_unit_ids is not None else []),
                 "prop_description": row.description,
                 "prop_id": row.id,
                 "prop_human_readable_id": row.human_readable_id,

diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py
index 0ebf5845ae..b38d299b27 100644
--- a/graphrag/index/context_switch/contextSwitcher.py
+++ b/graphrag/index/context_switch/contextSwitcher.py
@@ -225,13 +225,8 @@ def _read_config_parameters(root: str, config: str | None):
             if not optimized_search:
                 final_covariates = read_paraquet_file(input_storage_client, data_path +
"/create_final_covariates.parquet") - if config.graphdb.enabled: - final_relationships = graph_db_client.query_edges(context_id) - final_entities = graph_db_client.query_vertices(context_id) - else: - final_relationships = read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet") - final_entities = read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet") - + final_relationships = read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet") + final_entities = read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet") vector_store_args = ( config.embeddings.vector_store if config.embeddings.vector_store else {} @@ -253,6 +248,10 @@ def _read_config_parameters(root: str, config: str | None): description_embedding_store.load_entities(entities) if self.use_kusto_community_reports: description_embedding_store.load_reports(reports) + + if config.graphdb.enabled: + graph_db_client.write_vertices(final_entities) + graph_db_client.write_edges(final_relationships) def deactivate(self): """DeActivate the context.""" diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index c60c99524b..8ee745cb64 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -201,13 +201,11 @@ def run_local_search( final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) if config.graphdb.enabled: - final_relationships = pd.concat([final_relationships, graph_db_client.query_edges(context_id)]) final_entities = pd.concat([final_entities, graph_db_client.query_vertices(context_id)]) else: - final_relationships = pd.concat([final_relationships, read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")]) final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")]) - + graph_db_client._client.close() vector_store_args = ( config.embeddings.vector_store if config.embeddings.vector_store else {} ) @@ -242,12 +240,14 @@ def run_local_search( reports=reports, text_units=read_indexer_text_units(final_text_units), entities=entities, - relationships=read_indexer_relationships(final_relationships), + relationships=[], covariates={"claims": covariates}, description_embedding_store=description_embedding_store, response_type=response_type, + context_id=context_id, is_optimized_search=optimized_search, use_kusto_community_reports=use_kusto_community_reports, + graphdb_config=config.graphdb, ) if optimized_search: diff --git a/graphrag/query/context_builder/local_context.py b/graphrag/query/context_builder/local_context.py index bd1790980b..b48bf0bf96 100644 --- a/graphrag/query/context_builder/local_context.py +++ b/graphrag/query/context_builder/local_context.py @@ -7,6 +7,7 @@ from typing import Any, cast import pandas as pd +from common.graph_db_client import GraphDBClient import tiktoken from graphrag.model import Covariate, Entity, Relationship @@ -164,7 +165,8 @@ def build_relationship_context( relationship_ranking_attribute: str = "rank", column_delimiter: str = "|", context_name: str = "Relationships", - is_optimized_search: bool = False + is_optimized_search: bool = False, + graphdb_client: GraphDBClient|None=None, ) -> tuple[str, pd.DataFrame]: """Prepare relationship data tables as context data for system prompt.""" selected_relationships = _filter_relationships( @@ -172,6 +174,7 @@ def 
build_relationship_context( relationships=relationships, top_k_relationships=top_k_relationships, relationship_ranking_attribute=relationship_ranking_attribute, + graphdb_client=graphdb_client, ) if len(selected_entities) == 0 or len(selected_relationships) == 0: @@ -236,6 +239,7 @@ def _filter_relationships( relationships: list[Relationship], top_k_relationships: int = 10, relationship_ranking_attribute: str = "rank", + graphdb_client: GraphDBClient|None=None, ) -> list[Relationship]: """Filter and sort relationships based on a set of selected entities and a ranking attribute.""" # First priority: in-network relationships (i.e. relationships between selected entities) @@ -243,6 +247,7 @@ def _filter_relationships( selected_entities=selected_entities, relationships=relationships, ranking_attribute=relationship_ranking_attribute, + graphdb_client=graphdb_client, ) # Second priority - out-of-network relationships @@ -251,6 +256,7 @@ def _filter_relationships( selected_entities=selected_entities, relationships=relationships, ranking_attribute=relationship_ranking_attribute, + graphdb_client=graphdb_client, ) if len(out_network_relationships) <= 1: return in_network_relationships + out_network_relationships diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py index e9775f2bf2..c4320f380a 100644 --- a/graphrag/query/factories.py +++ b/graphrag/query/factories.py @@ -3,6 +3,7 @@ """Query Factory methods to support CLI.""" +from graphrag.config.models.graphdb_config import GraphDBConfig import tiktoken from azure.identity import DefaultAzureCredential, get_bearer_token_provider @@ -108,8 +109,10 @@ def get_local_search_engine( covariates: dict[str, list[Covariate]], response_type: str, description_embedding_store: BaseVectorStore, + context_id: str, is_optimized_search: bool = False, use_kusto_community_reports: bool = False, + graphdb_config: GraphDBConfig|None = None, ) -> LocalSearch: """Create a local search engine based on data + configuration.""" llm = get_llm(config) @@ -132,6 +135,8 @@ def get_local_search_engine( token_encoder=token_encoder, is_optimized_search= is_optimized_search, use_kusto_community_reports=use_kusto_community_reports, + graphdb_config=graphdb_config, + context_id=context_id, ), token_encoder=token_encoder, llm_params={ diff --git a/graphrag/query/input/retrieval/relationships.py b/graphrag/query/input/retrieval/relationships.py index 2ef5ba38cb..14ddf56180 100644 --- a/graphrag/query/input/retrieval/relationships.py +++ b/graphrag/query/input/retrieval/relationships.py @@ -7,22 +7,48 @@ import pandas as pd +from common.graph_db_client import GraphDBClient from graphrag.model import Entity, Relationship +from graphrag.query.input.loaders.dfs import read_relationships + +def get_relationships_from_graphdb(query:str,selected_entity_names:list[str],graphdb_client: GraphDBClient): + relationships_result=graphdb_client._client.submit( + message=query, + bindings={ + "prop_selected_entity_names": selected_entity_names, + } + ) + return read_relationships( + graphdb_client.result_to_df(relationships_result), + short_id_col="human_readable_id" + ) def get_in_network_relationships( selected_entities: list[Entity], relationships: list[Relationship], ranking_attribute: str = "rank", + graphdb_client: GraphDBClient|None=None, ) -> list[Relationship]: """Get all directed relationships between selected entities, sorted by ranking_attribute.""" selected_entity_names = [entity.title for entity in selected_entities] - selected_relationships = [ - relationship - for 
relationship in relationships - if relationship.source in selected_entity_names - and relationship.target in selected_entity_names - ] + if not graphdb_client: + selected_relationships = [ + relationship + for relationship in relationships + if relationship.source in selected_entity_names + and relationship.target in selected_entity_names + ] + else: + selected_relationships = get_relationships_from_graphdb( + query=( + "g.E()" + ".where(inV().has('name',within(prop_selected_entity_names)))" + ".where(outV().has('name',within(prop_selected_entity_names)))" + ), + selected_entity_names=selected_entity_names, + graphdb_client=graphdb_client + ) if len(selected_relationships) <= 1: return selected_relationships @@ -36,22 +62,37 @@ def get_out_network_relationships( selected_entities: list[Entity], relationships: list[Relationship], ranking_attribute: str = "rank", + graphdb_client: GraphDBClient|None=None, ) -> list[Relationship]: """Get relationships from selected entities to other entities that are not within the selected entities, sorted by ranking_attribute.""" selected_entity_names = [entity.title for entity in selected_entities] - source_relationships = [ - relationship - for relationship in relationships - if relationship.source in selected_entity_names - and relationship.target not in selected_entity_names - ] - target_relationships = [ - relationship - for relationship in relationships - if relationship.target in selected_entity_names - and relationship.source not in selected_entity_names - ] - selected_relationships = source_relationships + target_relationships + if not graphdb_client: + source_relationships = [ + relationship + for relationship in relationships + if relationship.source in selected_entity_names + and relationship.target not in selected_entity_names + ] + target_relationships = [ + relationship + for relationship in relationships + if relationship.target in selected_entity_names + and relationship.source not in selected_entity_names + ] + selected_relationships = source_relationships + target_relationships + else: + selected_relationships = get_relationships_from_graphdb( + query=( + "g.E().union(" + "__.where(outV().has('name',without(prop_selected_entity_names)))" + ".where(inV().has('name',within(prop_selected_entity_names)))," + "__.where(inV().has('name',without(prop_selected_entity_names)))" + ".where(outV().has('name',within(prop_selected_entity_names)))" + ")" + ), + selected_entity_names= selected_entity_names, + graphdb_client=graphdb_client + ) return sort_relationships_by_ranking_attribute( selected_relationships, selected_entities, ranking_attribute ) diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index 988d290848..af4c63d55b 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -6,6 +6,8 @@ from typing import Any import pandas as pd +from common.graph_db_client import GraphDBClient +from graphrag.config.models.graphdb_config import GraphDBConfig import tiktoken from graphrag.model import ( @@ -64,6 +66,8 @@ def __init__( embedding_vectorstore_key: str = EntityVectorStoreKey.ID, is_optimized_search: bool = False, use_kusto_community_reports: bool = False, + graphdb_config: GraphDBConfig|None = None, + context_id:str = None, ): if community_reports is None: community_reports = [] @@ -88,6 +92,8 @@ def __init__( self.embedding_vectorstore_key = embedding_vectorstore_key 
self.is_optimized_search = is_optimized_search self.use_kusto_community_reports = use_kusto_community_reports + self.graphdb_config = graphdb_config + self.context_id = context_id def filter_by_entity_keys(self, entity_keys: list[int] | list[str]): """Filter entity text embeddings by entity keys.""" @@ -433,6 +439,7 @@ def _build_local_context( final_context_data = {} # gradually add entities and associated metadata to the context until we reach limit + graphdb_client=GraphDBClient(self.graphdb_config,self.context_id) if (self.graphdb_config and self.graphdb_config.enabled) else None for entity in selected_entities: current_context = [] current_context_data = {} @@ -452,7 +459,8 @@ def _build_local_context( include_relationship_weight=include_relationship_weight, relationship_ranking_attribute=relationship_ranking_attribute, context_name="Relationships", - is_optimized_search=is_optimized_search + is_optimized_search=is_optimized_search, + graphdb_client=graphdb_client, ) current_context.append(relationship_context) current_context_data["relationships"] = relationship_context_data @@ -484,6 +492,8 @@ def _build_local_context( final_context_data = current_context_data # attach entity context to final context + if graphdb_client: + graphdb_client._client.close() final_context_text = entity_context + "\n\n" + "\n\n".join(final_context) final_context_data["entities"] = entity_context_data From c3330092fc50415a8c518477d73f4452b1f91ea2 Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Thu, 5 Sep 2024 11:44:28 -0700 Subject: [PATCH 58/87] Adding graphdb into for-loop per data_path of context b/c it should happen per path after the concats were removed. --- graphrag/index/context_switch/contextSwitcher.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index b38d299b27..737828517b 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -248,10 +248,10 @@ def _read_config_parameters(root: str, config: str | None): description_embedding_store.load_entities(entities) if self.use_kusto_community_reports: description_embedding_store.load_reports(reports) - - if config.graphdb.enabled: - graph_db_client.write_vertices(final_entities) - graph_db_client.write_edges(final_relationships) + + if config.graphdb.enabled: + graph_db_client.write_vertices(final_entities) + graph_db_client.write_edges(final_relationships) def deactivate(self): """DeActivate the context.""" From 1b66bd39650398fe2bc4ddbc5647ad8f8dfd89a0 Mon Sep 17 00:00:00 2001 From: Amritpal Singh Date: Thu, 5 Sep 2024 12:29:34 -0700 Subject: [PATCH 59/87] logs on file & stdout + unbuffered logs --- graphrag/index/cli.py | 8 +++++--- pyproject.toml | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index 63fa9c190b..fb1a5ed4c2 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -357,11 +357,13 @@ def _enable_logging(root_dir: str, run_id: str, verbose: bool) -> None: logging_file.parent.mkdir(parents=True, exist_ok=True) logging_file.touch(exist_ok=True) - + handler = logging.StreamHandler(stream=sys.stdout) + fileHandler = logging.FileHandler(logging_file, mode="a") logging.basicConfig( - filename=str(logging_file), - filemode="a", + #filename=str(logging_file), + #filemode="a", format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", datefmt="%H:%M:%S", 
level=logging.DEBUG if verbose else logging.INFO, + handlers=[handler, fileHandler] ) diff --git a/pyproject.toml b/pyproject.toml index 4ae07d7552..e511454ff4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,8 +128,8 @@ _test_all = "coverage run -m pytest ./tests" test_unit = "pytest ./tests/unit" test_integration = "pytest ./tests/integration" test_smoke = "pytest ./tests/smoke" -index = "python -m graphrag.index" -query = "python -m graphrag.query" +index = "python -u -m graphrag.index" +query = "python -u -m graphrag.query" prompt_tune = "python -m graphrag.prompt_tune" # Pass in a test pattern test_only = "pytest -s -k" From 44374d9c6179e4b17f2b2a82eaefbe64331a82ef Mon Sep 17 00:00:00 2001 From: sirus-ms Date: Fri, 6 Sep 2024 14:36:39 -0700 Subject: [PATCH 60/87] Fix cli when graphdb is not enabled. (#36) --- graphrag/query/cli.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 8ee745cb64..f2b14ef55b 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -202,10 +202,11 @@ def run_local_search( if config.graphdb.enabled: final_entities = pd.concat([final_entities, graph_db_client.query_vertices(context_id)]) + graph_db_client._client.close() else: final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")]) - graph_db_client._client.close() + vector_store_args = ( config.embeddings.vector_store if config.embeddings.vector_store else {} ) From 72a866bc78f963af89427a6cc35dbdfb8f75d47d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Salvador=20Barr=C3=B3n=20S=C3=A1nchez?= Date: Fri, 6 Sep 2024 11:15:34 -0700 Subject: [PATCH 61/87] Add graphdb parameters for local emulator support --- common/graph_db_client.py | 2 +- graphrag/config/create_graphrag_config.py | 2 ++ graphrag/config/models/graphdb_config.py | 10 ++++++++++ graphrag/index/context_switch/contextSwitcher.py | 5 ++++- graphrag/index/init_content.py | 2 ++ graphrag/query/input/retrieval/relationships.py | 3 +++ 6 files changed, 22 insertions(+), 2 deletions(-) diff --git a/common/graph_db_client.py b/common/graph_db_client.py index dc3b9d0e3a..704748039a 100644 --- a/common/graph_db_client.py +++ b/common/graph_db_client.py @@ -16,7 +16,7 @@ class GraphDBClient: def __init__(self,graph_db_params: GraphDBConfig|None,context_id: str|None): self.username_prefix=graph_db_params.username self._client=client.Client( - url=f"wss://{graph_db_params.account_name}.gremlin.cosmos.azure.com:443/", + url=f"{graph_db_params.gremlin_url}", traversal_source="g", username=self.username_prefix+"-contextid-"+context_id, password=f"{graph_db_params.account_key}", diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index 8477916736..34b953345e 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -560,6 +560,8 @@ def hydrate_parallelization_params( account_key=reader.str("account_key") or None, username=reader.str("username") or None, enabled=reader.bool("enabled") or False, + cosmos_url=reader.str("cosmos_url") or None, + gremlin_url=reader.str("gremlin_url") or None, ) encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL diff --git a/graphrag/config/models/graphdb_config.py b/graphrag/config/models/graphdb_config.py index 9be8c8751c..8ee0f9d276 100644 --- a/graphrag/config/models/graphdb_config.py +++ b/graphrag/config/models/graphdb_config.py @@ -24,4 +24,14 @@ 
class GraphDBConfig(BaseModel): enabled: bool = Field( description="Flag to enable querying into graphdb", default=False + ) + + cosmos_url: str|None = Field( + description="Cosmos account url", + default=None, + ) + + gremlin_url: str|None = Field( + description="Gremlin db url", + default=None, ) \ No newline at end of file diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index 737828517b..a864988a65 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -199,7 +199,7 @@ def _read_config_parameters(root: str, config: str | None): final_covariates = pd.DataFrame() if config.graphdb.enabled: cosmos_client = CosmosClient( - f"https://{config.graphdb.account_name}.documents.azure.com:443/", + f"{config.graphdb.cosmos_url}", f"{config.graphdb.account_key}", ) database_name = config.graphdb.username.split("/")[2] @@ -252,6 +252,9 @@ def _read_config_parameters(root: str, config: str | None): if config.graphdb.enabled: graph_db_client.write_vertices(final_entities) graph_db_client.write_edges(final_relationships) + + if config.graphdb.enabled: + graph_db_client._client.close() def deactivate(self): """DeActivate the context.""" diff --git a/graphrag/index/init_content.py b/graphrag/index/init_content.py index 5214abcc85..8f5982f8f7 100644 --- a/graphrag/index/init_content.py +++ b/graphrag/index/init_content.py @@ -167,6 +167,8 @@ account_key: '' username: '' enabled: false + cosmos_url: '' + gremlin_url: '' """ INIT_DOTENV = """ diff --git a/graphrag/query/input/retrieval/relationships.py b/graphrag/query/input/retrieval/relationships.py index 14ddf56180..7be258d23c 100644 --- a/graphrag/query/input/retrieval/relationships.py +++ b/graphrag/query/input/retrieval/relationships.py @@ -3,6 +3,7 @@ """Util functions to retrieve relationships from a collection.""" +import time from typing import Any, cast import pandas as pd @@ -19,6 +20,8 @@ def get_relationships_from_graphdb(query:str,selected_entity_names:list[str],gra "prop_selected_entity_names": selected_entity_names, } ) + time.sleep(5) + print(graphdb_client.result_to_df(relationships_result)) return read_relationships( graphdb_client.result_to_df(relationships_result), short_id_col="human_readable_id" From 2baa18dbc3eaf2b056465eed950ee8e36fb9044d Mon Sep 17 00:00:00 2001 From: amritpalms Date: Sat, 7 Sep 2024 18:39:00 -0700 Subject: [PATCH 62/87] incline to run in azure --- common/graph_db_client.py | 9 ++++++++- graphrag/common/storage/blob_pipeline_storage.py | 2 +- graphrag/index/cli.py | 1 + graphrag/index/context_switch/contextSwitcher.py | 6 +++--- graphrag/index/input/load_input.py | 1 - graphrag/query/cli.py | 7 ++++--- graphrag/vector_stores/kusto.py | 16 +++++++++++----- 7 files changed, 28 insertions(+), 14 deletions(-) diff --git a/common/graph_db_client.py b/common/graph_db_client.py index dc3b9d0e3a..d451f2a662 100644 --- a/common/graph_db_client.py +++ b/common/graph_db_client.py @@ -7,19 +7,26 @@ import ast from gremlin_python.driver import client, serializer +from azure.identity import ManagedIdentityCredential import time import os import json +# Azure Cosmos DB Gremlin Endpoint and other constants +COSMOS_DB_SCOPE = "https://cosmos.azure.com/.default" # The scope for Cosmos DB class GraphDBClient: def __init__(self,graph_db_params: GraphDBConfig|None,context_id: str|None): self.username_prefix=graph_db_params.username + token = f"{graph_db_params.account_key}" + if(os.environ.get("ENVIRONMENT") 
== "AZURE"): + credential = ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3") + token = credential.get_token(COSMOS_DB_SCOPE) self._client=client.Client( url=f"wss://{graph_db_params.account_name}.gremlin.cosmos.azure.com:443/", traversal_source="g", username=self.username_prefix+"-contextid-"+context_id, - password=f"{graph_db_params.account_key}", + password=token", message_serializer=serializer.GraphSONSerializersV2d0(), ) diff --git a/graphrag/common/storage/blob_pipeline_storage.py b/graphrag/common/storage/blob_pipeline_storage.py index 6acc761e5c..568ec89bcc 100644 --- a/graphrag/common/storage/blob_pipeline_storage.py +++ b/graphrag/common/storage/blob_pipeline_storage.py @@ -50,7 +50,7 @@ def __init__( self._blob_service_client = BlobServiceClient( account_url=storage_account_blob_url, - credential=DefaultAzureCredential(), + credential=DefaultAzureCredential(managed_identity_client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), ) self._encoding = encoding or "utf-8" self._container_name = container_name diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index fb1a5ed4c2..795d4e22fb 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -200,6 +200,7 @@ def _switch_context(root: str, config: str, use_kusto_community_reports: bool) -> None: """Switch the context to the given context.""" reporter.info(f"Switching context to {context_id} using operation {context_operation}") + logging.info("Switching context to {context_id}") from graphrag.index.context_switch.contextSwitcher import ContextSwitcher context_switcher = ContextSwitcher( root_dir=root, diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index 737828517b..24f62eb5b8 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -29,6 +29,7 @@ from azure.cosmos import CosmosClient, PartitionKey from graphrag.vector_stores.base import BaseVectorStore from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType +import logging class ContextSwitcher: """ContextSwitcher class definition.""" @@ -47,6 +48,7 @@ def __init__(self, root_dir:str , config_dir:str,reporter: ProgressReporter, self.optimized_search=optimized_search self.community_level = community_level self.use_kusto_community_reports = use_kusto_community_reports + logging.info("ContextSwitcher initialized") def setup_vector_store(self, config_args: dict | None = None,) -> BaseVectorStore: @@ -181,13 +183,11 @@ def _read_config_parameters(root: str, config: str | None): ValueError("Memory storage is not supported") if(config.storage.type == StorageType.blob): if(config.storage.container_name is not None): - input_storage_client: PipelineStorage = BlobPipelineStorage(config.storage.connection_string, config.storage.container_name) - output_storage_client: PipelineStorage = BlobPipelineStorage(config.storage.connection_string, config.storage.container_name) + input_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string, container_name=config.storage.container_name, storage_account_blob_url=config.storage.storage_account_blob_url) else: ValueError("Storage type is Blob but container name is invalid") if(config.storage.type == StorageType.file): input_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir) - output_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir) data_paths = [] data_paths = 
get_files_by_contextid(config, context_id) diff --git a/graphrag/index/input/load_input.py b/graphrag/index/input/load_input.py index f4aaa5cf5a..cd1c655247 100644 --- a/graphrag/index/input/load_input.py +++ b/graphrag/index/input/load_input.py @@ -60,7 +60,6 @@ async def load_input( connection_string=config.connection_string, storage_account_blob_url=config.storage_account_blob_url, container_name=config.container_name, - path_prefix=config.base_dir, ) case InputType.file: log.info("using file storage for input") diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index f2b14ef55b..488e95c590 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -171,8 +171,8 @@ def run_local_search( ValueError("Memory storage is not supported") if(config.storage.type == StorageType.blob): if(config.storage.container_name is not None): - input_storage_client: PipelineStorage = BlobPipelineStorage(config.storage.connection_string, config.storage.container_name) - output_storage_client: PipelineStorage = BlobPipelineStorage(config.storage.connection_string, config.storage.container_name) + input_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string, container_name=config.storage.container_name, storage_account_blob_url=config.storage.storage_account_blob_url) + output_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string, container_name=config.storage.container_name, storage_account_blob_url=config.storage.storage_account_blob_url) else: ValueError("Storage type is Blob but container name is invalid") if(config.storage.type == StorageType.file): @@ -206,7 +206,8 @@ def run_local_search( else: final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")]) - + if config.graphdb.enabled: + graph_db_client._client.close() vector_store_args = ( config.embeddings.vector_store if config.embeddings.vector_store else {} ) diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 71113c5d04..468c5e5428 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -2,7 +2,7 @@ # Licensed under the MIT License """The Azure Kusto vector storage implementation package.""" - +import os import typing from azure.kusto.data import KustoClient, KustoConnectionStringBuilder from azure.kusto.data.helpers import dataframe_from_result_table @@ -56,10 +56,16 @@ def connect(self, **kwargs: Any) -> Any: client_id = kwargs.get("client_id") client_secret = kwargs.get("client_secret") authority_id = kwargs.get("authority_id") - - kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication( - str(cluster), str(client_id), str(client_secret), str(authority_id) - ) + env = os.environ.get("ENVIRONMENT") + if(env == "AZURE"): + kcsb = KustoConnectionStringBuilder.with_aad_managed_service_identity_authentication( + str(cluster), client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3" + ) + elif(env == "DEVELOPMENT"): + kcsb = KustoConnectionStringBuilder.with_aad_device_authentication(str(cluster)) + else: + kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication( + str(cluster), str(client_id), str(client_secret), str(authority_id)) self.client = KustoClient(kcsb) self.database = database From 0e1df62715a02791a4e6478422330147e77b8dd1 Mon Sep 17 00:00:00 2001 From: amritpalms Date: Sat, 7 Sep 2024 18:53:17 -0700 Subject: [PATCH 63/87] commenting managed identity code --- 
common/graph_db_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/graph_db_client.py b/common/graph_db_client.py index d451f2a662..0b5cb19f64 100644 --- a/common/graph_db_client.py +++ b/common/graph_db_client.py @@ -19,9 +19,9 @@ class GraphDBClient: def __init__(self,graph_db_params: GraphDBConfig|None,context_id: str|None): self.username_prefix=graph_db_params.username token = f"{graph_db_params.account_key}" - if(os.environ.get("ENVIRONMENT") == "AZURE"): - credential = ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3") - token = credential.get_token(COSMOS_DB_SCOPE) + #if(os.environ.get("ENVIRONMENT") == "AZURE"): + # credential = ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3") + # token = credential.get_token(COSMOS_DB_SCOPE) self._client=client.Client( url=f"wss://{graph_db_params.account_name}.gremlin.cosmos.azure.com:443/", traversal_source="g", From a92787efa3f761c538129d6385ec8c6ddc807da2 Mon Sep 17 00:00:00 2001 From: amritpalms Date: Sat, 7 Sep 2024 19:02:04 -0700 Subject: [PATCH 64/87] minor fix --- common/graph_db_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/graph_db_client.py b/common/graph_db_client.py index 0b5cb19f64..eb5a96e84d 100644 --- a/common/graph_db_client.py +++ b/common/graph_db_client.py @@ -26,7 +26,7 @@ def __init__(self,graph_db_params: GraphDBConfig|None,context_id: str|None): url=f"wss://{graph_db_params.account_name}.gremlin.cosmos.azure.com:443/", traversal_source="g", username=self.username_prefix+"-contextid-"+context_id, - password=token", + password=token, message_serializer=serializer.GraphSONSerializersV2d0(), ) From 6a13d553b162d1121d8f0a316829d239b333b2b2 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Mon, 9 Sep 2024 13:32:24 -0700 Subject: [PATCH 65/87] fix kusto cli --- graphrag/query/cli.py | 131 +++++++++++++++++++++++++----------------- 1 file changed, 77 insertions(+), 54 deletions(-) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 488e95c590..48b5650e53 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -164,61 +164,92 @@ def run_local_search( data_dir, root_dir, config_dir ) - # TODO: loading stage here must be only limited to default lancedb. - # for the POC purpose input artifacts blob, output artifacts blob and input query blob storage are going to same. 
- if(config.storage.type == StorageType.memory): - ValueError("Memory storage is not supported") + + vector_store_args = ( + config.embeddings.vector_store if config.embeddings.vector_store else {} + ) + + reporter.info(f"Vector Store Args: {vector_store_args}") + vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) + + entities=[] + text_units=[] + covariates=[] + reports=[] + final_relationships=[] + if(config.storage.type == StorageType.blob): if(config.storage.container_name is not None): - input_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string, container_name=config.storage.container_name, storage_account_blob_url=config.storage.storage_account_blob_url) - output_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string, container_name=config.storage.container_name, storage_account_blob_url=config.storage.storage_account_blob_url) + output_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string, + container_name=config.storage.container_name, + storage_account_blob_url=config.storage.storage_account_blob_url) else: ValueError("Storage type is Blob but container name is invalid") - if(config.storage.type == StorageType.file): - input_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir) + elif(config.storage.type == StorageType.file): output_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir) - data_paths = [] - data_paths = get_files_by_contextid(config, context_id) - final_nodes = pd.DataFrame() - final_community_reports = pd.DataFrame() - final_text_units = pd.DataFrame() - final_relationships = pd.DataFrame() - final_entities = pd.DataFrame() - final_covariates = pd.DataFrame() - if config.graphdb.enabled: - graph_db_client = GraphDBClient(config.graphdb,context_id) - for data_path in data_paths: - #check from the config for the ouptut storage type and then read the data from the storage. - #GraphDB: we may need to make change below to read nodes data from Graph DB - final_nodes = pd.concat([final_nodes, read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")]) - final_community_reports = pd.concat([final_community_reports,read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) # KustoDB: Final_entities, Final_Nodes, Final_report should be merged and inserted to kusto - final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) # lance db search need it for embedding mapping. we have embeddings in entities we should use from there. KustoDB already must have sorted it. - if not optimized_search: - final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) + ##### LEGACY ####################### + + if vector_store_type == VectorStoreType.LanceDB: + # for the POC purpose input artifacts blob, output artifacts blob and input query blob storage are going to same. 
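# Annotation (not patch content): the rewrite makes the store type the first
# branch point in run_local_search; condensed, the control flow becomes:
#     vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
#     if vector_store_type == VectorStoreType.LanceDB:
#         ...legacy path: read parquet artifacts, build entities/reports in memory...
#     # any other store (Kusto) keeps entities and reports server-side, lists stay empty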
+ if(config.storage.type == StorageType.memory): + ValueError("Memory storage is not supported") + if(config.storage.type == StorageType.blob): + if(config.storage.container_name is not None): + input_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string, + container_name=config.storage.container_name, + storage_account_blob_url=config.storage.storage_account_blob_url) + else: + ValueError("Storage type is Blob but container name is invalid") + if(config.storage.type == StorageType.file): + input_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir) + + + data_paths = [] + data_paths = get_files_by_contextid(config, context_id) + final_nodes = pd.DataFrame() + final_community_reports = pd.DataFrame() + final_text_units = pd.DataFrame() + final_relationships = pd.DataFrame() + final_entities = pd.DataFrame() + final_covariates = pd.DataFrame() + + for data_path in data_paths: + #check from the config for the ouptut storage type and then read the data from the storage. + + #GraphDB: we may need to make change below to read nodes data from Graph DB + final_nodes = pd.concat([final_nodes, read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")]) + final_community_reports = pd.concat([final_community_reports,read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) # KustoDB: Final_entities, Final_Nodes, Final_report should be merged and inserted to kusto + final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) # lance db search need it for embedding mapping. we have embeddings in entities we should use from there. KustoDB already must have sorted it. + final_relationships = pd.concat([final_text_units,read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")]) + + if not optimized_search: + final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) - if config.graphdb.enabled: - final_entities = pd.concat([final_entities, graph_db_client.query_vertices(context_id)]) - graph_db_client._client.close() - else: final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")]) - if config.graphdb.enabled: - graph_db_client._client.close() - vector_store_args = ( - config.embeddings.vector_store if config.embeddings.vector_store else {} - ) + ############# End of for loop - reporter.info(f"Vector Store Args: {vector_store_args}") - vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) + entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it. + reports=read_indexer_reports( + final_community_reports, final_nodes, community_level + ) + + covariates = ( + read_indexer_covariates(final_covariates) + if final_covariates.empty is False + else [] + ) + text_units=read_indexer_text_units(final_text_units) + + elif not use_kusto_community_reports: + print("\n\n[!] WARNING: Passing empty reports.\n\n") + + ######################################################################################## - entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it. 
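# Annotation (not patch content): the in-memory branch ends by adapting the
# merged DataFrames into model objects via the indexer adapters, e.g.:
#     entities = read_indexer_entities(final_nodes, final_entities, community_level)
#     text_units = read_indexer_text_units(final_text_units)
# whereas the Kusto path leaves these lists empty and resolves context at query time.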
- reports=read_indexer_reports( - final_community_reports, final_nodes, community_level - ) description_embedding_store = __get_embedding_description_store( entities=entities, vector_store_type=vector_store_type, @@ -226,23 +257,15 @@ def run_local_search( context_id=context_id, ) - covariates = ( - read_indexer_covariates(final_covariates) - if final_covariates.empty is False - else [] - ) - - if(isinstance(description_embedding_store, KustoVectorStore)): - entities = [] - if use_kusto_community_reports: - reports = [] - + ''' + *** If KUSTO is enabled, both entities and final_relationships must be empty. + ''' search_engine = get_local_search_engine( config, reports=reports, - text_units=read_indexer_text_units(final_text_units), + text_units=text_units, entities=entities, - relationships=[], + relationships=final_relationships, covariates={"claims": covariates}, description_embedding_store=description_embedding_store, response_type=response_type, From 25aad8ce8bdbb9333b49440e01c23d2b16d67662 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Mon, 9 Sep 2024 17:07:01 -0700 Subject: [PATCH 66/87] fix legacy temporarily disable community reports for kusto --- graphrag/query/cli.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 48b5650e53..7b3425b480 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -224,7 +224,7 @@ def run_local_search( final_nodes = pd.concat([final_nodes, read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")]) final_community_reports = pd.concat([final_community_reports,read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) # KustoDB: Final_entities, Final_Nodes, Final_report should be merged and inserted to kusto final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) # lance db search need it for embedding mapping. we have embeddings in entities we should use from there. KustoDB already must have sorted it. - final_relationships = pd.concat([final_text_units,read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")]) + final_relationships = pd.concat([final_relationships,read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")]) if not optimized_search: final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")]) @@ -238,6 +238,8 @@ def run_local_search( final_community_reports, final_nodes, community_level ) + final_relationships=read_indexer_relationships(final_relationships) + covariates = ( read_indexer_covariates(final_covariates) if final_covariates.empty is False @@ -245,11 +247,12 @@ def run_local_search( ) text_units=read_indexer_text_units(final_text_units) - elif not use_kusto_community_reports: - print("\n\n[!] 
WARNING: Passing empty reports.\n\n") ######################################################################################## + if use_kusto_community_reports: + ValueError("Using community reports is not supported.") + description_embedding_store = __get_embedding_description_store( entities=entities, vector_store_type=vector_store_type, From 51c77209aef2d5e61a1853979eb0f7e4dd8d3d45 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Mon, 9 Sep 2024 22:52:24 -0700 Subject: [PATCH 67/87] Implement deactivation switch --- .../index/context_switch/contextSwitcher.py | 110 ++++++++++-------- graphrag/vector_stores/azure_ai_search.py | 5 +- graphrag/vector_stores/base.py | 6 +- graphrag/vector_stores/lancedb.py | 5 +- 4 files changed, 73 insertions(+), 53 deletions(-) diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index 254efad514..d770cd2719 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -50,27 +50,33 @@ def __init__(self, root_dir:str , config_dir:str,reporter: ProgressReporter, self.use_kusto_community_reports = use_kusto_community_reports logging.info("ContextSwitcher initialized") - def setup_vector_store(self, - config_args: dict | None = None,) -> BaseVectorStore: + def get_embedding_store(self,config_args): """Set up the vector store and return it.""" if not config_args: - config_args = {} + config_args = {} collection_name = config_args.get( - "query_collection_name", "entity_description_embeddings" + "query_collection_name", "entity_description_embeddings" ) collection_name += "_" + self.context_id config_args.update({"collection_name": collection_name}) vector_name = config_args.get( - "vector_search_column", "description_embedding" + "vector_search_column", "description_embedding" ) config_args.update({"vector_name": vector_name}) config_args.update({"reports_name": f"reports_{self.context_id}"}) - description_embedding_store = VectorStoreFactory.get_vector_store( - vector_store_type=VectorStoreType.Kusto, kwargs=config_args + return VectorStoreFactory.get_vector_store( + vector_store_type=VectorStoreType.Kusto, kwargs=config_args ) + + + + def setup_vector_store(self, + config_args: dict | None = None,) -> BaseVectorStore: + + description_embedding_store = self.get_embedding_store(config_args) description_embedding_store.connect(**config_args) description_embedding_store.setup_entities() @@ -79,6 +85,43 @@ def setup_vector_store(self, return description_embedding_store + def _read_config_parameters(self,root: str, config: str | None): + reporter=self.reporter + _root = Path(root) + settings_yaml = ( + Path(config) + if config and Path(config).suffix in [".yaml", ".yml"] + else _root / "settings.yaml" + ) + if not settings_yaml.exists(): + settings_yaml = _root / "settings.yml" + + if settings_yaml.exists(): + reporter.info(f"Reading settings from {settings_yaml}") + with settings_yaml.open( + "rb", + ) as file: + import yaml + + data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + settings_json = ( + Path(config) + if config and Path(config).suffix == ".json" + else _root / "settings.json" + ) + if settings_json.exists(): + reporter.info(f"Reading settings from {settings_json}") + with settings_json.open("rb") as file: + import json + + data = json.loads(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + reporter.info("Reading 
settings from environment variables") + return create_graphrag_config(root_dir=root) + def activate(self): """Activate the context.""" #1. read the context id to fileId mapping. @@ -131,46 +174,7 @@ def _create_graphrag_config( config_dir: str | None, ) -> GraphRagConfig: """Create a GraphRag configuration.""" - return _read_config_parameters(root or "./", config_dir) - - - def _read_config_parameters(root: str, config: str | None): - _root = Path(root) - settings_yaml = ( - Path(config) - if config and Path(config).suffix in [".yaml", ".yml"] - else _root / "settings.yaml" - ) - if not settings_yaml.exists(): - settings_yaml = _root / "settings.yml" - - if settings_yaml.exists(): - reporter.info(f"Reading settings from {settings_yaml}") - with settings_yaml.open( - "rb", - ) as file: - import yaml - - data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) - return create_graphrag_config(data, root) - - settings_json = ( - Path(config) - if config and Path(config).suffix == ".json" - else _root / "settings.json" - ) - if settings_json.exists(): - reporter.info(f"Reading settings from {settings_json}") - with settings_json.open("rb") as file: - import json - - data = json.loads(file.read().decode(encoding="utf-8", errors="strict")) - return create_graphrag_config(data, root) - - reporter.info("Reading settings from environment variables") - return create_graphrag_config(root_dir=root) - - + return self._read_config_parameters(root or "./", config_dir) ################################################################################ @@ -197,6 +201,8 @@ def _read_config_parameters(root: str, config: str | None): final_relationships = pd.DataFrame() final_entities = pd.DataFrame() final_covariates = pd.DataFrame() + graph_db_client=None + if config.graphdb.enabled: cosmos_client = CosmosClient( f"{config.graphdb.cosmos_url}", @@ -252,11 +258,15 @@ def _read_config_parameters(root: str, config: str | None): if config.graphdb.enabled: graph_db_client.write_vertices(final_entities) graph_db_client.write_edges(final_relationships) - + if config.graphdb.enabled: graph_db_client._client.close() def deactivate(self): """DeActivate the context.""" - #1. Delete all the data for a given context id. 
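        # A minimal sketch of the teardown the hunk below switches to, in place
        # of the old set_ctx_activation(0) placeholder (hypothetical helper
        # name; the attributes and methods are the ones this series defines):
        #   resolve settings, connect to the per-context vector store, and
        #   drop its tables through the new unload_entities() hook.
        def deactivate_sketch(switcher) -> None:
            config = switcher._read_config_parameters(switcher.root_dir or "./", switcher.config_dir)
            config_args = config.embeddings.vector_store
            store = switcher.get_embedding_store(config_args)  # per-context Kusto store
            store.connect(**config_args)
            store.unload_entities()  # drops this context's tables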
- self.set_ctx_activation(0) \ No newline at end of file + + config=self._read_config_parameters(self.root_dir or "./",self.config_dir) + config_args = config.embeddings.vector_store + description_embedding_store = self.get_embedding_store(config_args) + description_embedding_store.connect(**config_args) + description_embedding_store.unload_entities() \ No newline at end of file diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index 7e8de25f38..7be75d44a5 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -212,4 +212,7 @@ def setup_entities(self) -> None: raise NotImplementedError("Setting up entities is not supported for Azure AI Search") def setup_reports(self) -> None: - raise NotImplementedError("Setting up reports is not supported for Azure AI Search") \ No newline at end of file + raise NotImplementedError("Setting up reports is not supported for Azure AI Search") + + def unload_entities(self) -> None: + raise NotImplementedError("unload_entities(): Unsupported for this vector store.") \ No newline at end of file diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index f198b5a34b..a968f56a6d 100644 --- a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -112,4 +112,8 @@ def setup_entities(self) -> None: @abstractmethod def setup_reports(self) -> None: - """Setup the reports in the vector-store.""" \ No newline at end of file + """Setup the reports in the vector-store.""" + + @abstractmethod + def unload_entities(self) -> None: + """Remove context from the databases.""" \ No newline at end of file diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index f9e0e40528..4bff541703 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ -139,4 +139,7 @@ def setup_entities(self) -> None: raise NotImplementedError("Setting up entities is not supported for LanceDB") def setup_reports(self) -> None: - raise NotImplementedError("Setting up community reports is not supported for LanceDB") \ No newline at end of file + raise NotImplementedError("Setting up community reports is not supported for LanceDB") + + def unload_entities(self) -> None: + raise NotImplementedError("unload_entities(): Unsupported for this vector store.") \ No newline at end of file From 544f4b7ff3b0fb13b4701c785229abb70797bded Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Tue, 10 Sep 2024 09:48:44 -0700 Subject: [PATCH 68/87] Add deactivation switch for graphdb --- common/graph_db_client.py | 3 +++ graphrag/index/context_switch/contextSwitcher.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/common/graph_db_client.py b/common/graph_db_client.py index a7830ddd3c..ab4503f01c 100644 --- a/common/graph_db_client.py +++ b/common/graph_db_client.py @@ -50,6 +50,9 @@ def result_to_df(self,result) -> pd.DataFrame: df = pd.DataFrame(json_data) return df + def remove_graph(self): + self._client.submit(message=("g.V().drop()")) + def query_vertices(self,context_id:str) -> pd.DataFrame: result = self._client.submit( message=( diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index d770cd2719..f43c4e53dd 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -269,4 +269,8 @@ def deactivate(self): config_args = config.embeddings.vector_store description_embedding_store = 
self.get_embedding_store(config_args) description_embedding_store.connect(**config_args) - description_embedding_store.unload_entities() \ No newline at end of file + description_embedding_store.unload_entities() + + if config.graphdb.enabled: + g_client=GraphDBClient(config.graphdb,self.context_id) + g_client.remove_graph() \ No newline at end of file From a50c65c0cac7a450139888cc013e64c31e63617e Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Thu, 12 Sep 2024 10:49:46 -0700 Subject: [PATCH 69/87] Allowing multiple files to be indexed. --- graphrag/index/input/text.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/graphrag/index/input/text.py b/graphrag/index/input/text.py index 51a5da9d82..3a3e15cbd9 100644 --- a/graphrag/index/input/text.py +++ b/graphrag/index/input/text.py @@ -55,11 +55,7 @@ async def load_file( if len(files) == 0: msg = f"No text files found in {config.base_dir}" raise ValueError(msg) - - if len(files) > 1: - msg = f"found more than 1 files in base dir {config.base_dir}" - raise ValueError(msg) - + found_files = f"found text files from {config.base_dir}, found {files}" log.info(found_files) From 7232a68ed67a57449028aa8cae2482bdd44d0242 Mon Sep 17 00:00:00 2001 From: Liam Dannaher Date: Mon, 16 Sep 2024 12:34:33 -0700 Subject: [PATCH 70/87] Initial code for the different query paths. --- graphrag/query/__main__.py | 8 +++ graphrag/query/cli.py | 103 +++++++++++++++++++++++++++++++++++-- 2 files changed, 108 insertions(+), 3 deletions(-) diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py index e0a5e9f583..cbaab1fcb4 100644 --- a/graphrag/query/__main__.py +++ b/graphrag/query/__main__.py @@ -89,6 +89,13 @@ def __str__(self): action="store_true", ) + parser.add_argument( + "--paths", + help="Different paths for the query", + action="store_true", + default=0, # Default to normal graphrag search + ) + parser.add_argument( "query", nargs=1, @@ -110,6 +117,7 @@ def __str__(self): args.query[0], optimized_search=args.optimized_search, use_kusto_community_reports=args.use_kusto_community_reports, + paths=paths, ) case SearchType.GLOBAL: run_global_search( diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 7b3425b480..3a96c6a823 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -8,6 +8,9 @@ from pathlib import Path from typing import cast from io import BytesIO + +from datashaper import VerbCallbacks +from graphrag.common.progress.rich import RichProgressReporter from graphrag.common.storage import PipelineStorage, BlobPipelineStorage, FilePipelineStorage from graphrag.common.utils.context_utils import get_files_by_contextid from graphrag.config.enums import StorageType @@ -20,6 +23,8 @@ GraphRagConfig, ) from graphrag.common.progress import PrintProgressReporter +from graphrag.index.verbs.entities.extraction.strategies.graph_intelligence.run_graph_intelligence import run_gi +from graphrag.index.verbs.entities.extraction.strategies.typing import Document from graphrag.model.entity import Entity from graphrag.query.input.loaders.dfs import ( store_entity_semantic_embeddings, @@ -147,8 +152,7 @@ def run_global_search( reporter.success(f"Global Search Response: {result.response}") return result.response - -def run_local_search( +def path0( config_dir: str | None, data_dir: str | None, root_dir: str | None, @@ -158,7 +162,8 @@ def run_local_search( query: str, optimized_search: bool = False, use_kusto_community_reports: bool = False, -): + ): + """Run a local search with the given query.""" data_dir, 
root_dir, config = _configure_paths_and_settings( data_dir, root_dir, config_dir @@ -287,6 +292,98 @@ def run_local_search( reporter.success(f"Local Search Response: {result.response}") return result.response +def path1( + config_dir: str | None, + data_dir: str | None, + root_dir: str | None, + community_level: int, + response_type: str, + context_id: str, + query: str, + optimized_search: bool = False, + use_kusto_community_reports: bool = False, + ): + ValueError("Not implemented") + +def path2( + config_dir: str | None, + data_dir: str | None, + root_dir: str | None, + community_level: int, + response_type: str, + context_id: str, + query: str, + optimized_search: bool = False, + use_kusto_community_reports: bool = False, + ): + """Path 2 + Find all the emails sent to trader by Tim Belden + a. Query -> LLM -> Entity Extracted -> 5 entities -> Set A [TimBelden1] + b. Query -> LLM -> Embeddings -> Y [x1..... Xn] + c. Run the query on Kusto for embedding Y [x1.....xn] for entitYid in [TimBelden1] + 4. Get the text units and get the response""" + data_dir, root_dir, config = _configure_paths_and_settings( + data_dir, root_dir, config_dir + ) + + # Populate args with dict of arguments for the LLM + args = {} + args['api_key'] = config.llm.api_key + args['type'] = config.llm.type + args['model'] = config.llm.model + args['model_supports_json'] = config.llm.model_supports_json + args['api_base'] = config.llm.api_base + args['api_version'] = config.llm.api_version + args['deployment_name'] = config.llm.deployment_name + llmm = {} + llmm['llm'] = args + + + result = asyncio.run(run_gi( + docs=[Document(text=query, id='0')], + entity_types=config.entity_extraction.entity_types, + reporter = None, + pipeline_cache=None, + args=llmm, + )) + + print(result.entities) + exit(0) + +def path3( + config_dir: str | None, + data_dir: str | None, + root_dir: str | None, + community_level: int, + response_type: str, + context_id: str, + query: str, + optimized_search: bool = False, + use_kusto_community_reports: bool = False, + ): + ValueError("Not implemented") + + +def run_local_search( + config_dir: str | None, + data_dir: str | None, + root_dir: str | None, + community_level: int, + response_type: str, + context_id: str, + query: str, + optimized_search: bool = False, + use_kusto_community_reports: bool = False, + paths: int = 0,): + """Run a local search with the given query.""" + if(paths==1): + return path1(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports) + elif(paths==2): + return path2(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports) + elif(paths==3): + return path3(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports) + return path0(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports) + def blob_exists(container_client, blob_name): blob_client = container_client.get_blob_client(blob_name) try: From 38254efec4ce7b448808728118971260b5051d78 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Mon, 16 Sep 2024 16:21:33 -0700 Subject: [PATCH 71/87] Add missing args --- graphrag/query/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py index cbaab1fcb4..2739213bc9 100644 --- a/graphrag/query/__main__.py +++ 
b/graphrag/query/__main__.py @@ -117,7 +117,7 @@ def __str__(self): args.query[0], optimized_search=args.optimized_search, use_kusto_community_reports=args.use_kusto_community_reports, - paths=paths, + paths=args.paths, ) case SearchType.GLOBAL: run_global_search( From b5cbe43986dde492a4145f164a39af64bd1cafd4 Mon Sep 17 00:00:00 2001 From: logomachic Date: Tue, 17 Sep 2024 08:42:11 -0700 Subject: [PATCH 72/87] Update __main__.py paths type --- graphrag/query/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py index 2739213bc9..e2e01d6fa6 100644 --- a/graphrag/query/__main__.py +++ b/graphrag/query/__main__.py @@ -92,7 +92,7 @@ def __str__(self): parser.add_argument( "--paths", help="Different paths for the query", - action="store_true", + type=int, default=0, # Default to normal graphrag search ) From 62a3495bcb85c150d746075aadf49ea400e04d89 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Tue, 17 Sep 2024 11:59:05 -0700 Subject: [PATCH 73/87] Add text units to kusto --- graphrag/index/context_switch/contextSwitcher.py | 9 +++++++++ graphrag/vector_stores/azure_ai_search.py | 3 +++ graphrag/vector_stores/base.py | 6 ++++++ graphrag/vector_stores/kusto.py | 16 +++++++++++++++- graphrag/vector_stores/lancedb.py | 5 ++++- 5 files changed, 37 insertions(+), 2 deletions(-) diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index f43c4e53dd..9547fbaa57 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -24,6 +24,7 @@ from graphrag.query.indexer_adapters import ( read_indexer_entities, read_indexer_reports, + read_indexer_text_units, ) from graphrag.model.entity import Entity from azure.cosmos import CosmosClient, PartitionKey @@ -61,12 +62,16 @@ def get_embedding_store(self,config_args): collection_name += "_" + self.context_id config_args.update({"collection_name": collection_name}) + vector_name = config_args.get( "vector_search_column", "description_embedding" ) config_args.update({"vector_name": vector_name}) config_args.update({"reports_name": f"reports_{self.context_id}"}) + + config_args.update({"text_units_name": f"text_units_{self.context_id}"}) + return VectorStoreFactory.get_vector_store( vector_store_type=VectorStoreType.Kusto, kwargs=config_args ) @@ -83,6 +88,8 @@ def setup_vector_store(self, if self.use_kusto_community_reports: description_embedding_store.setup_reports() + description_embedding_store.setup_text_units() + return description_embedding_store def _read_config_parameters(self,root: str, config: str | None): @@ -250,10 +257,12 @@ def _create_graphrag_config( entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it. 
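    # A minimal sketch of the per-context naming this patch completes: every
    # Kusto table is suffixed with the context id, so activating or tearing
    # down one context leaves the others untouched. (Hypothetical helper; the
    # three names mirror get_embedding_store and the series' defaults.)
    def per_context_tables(context_id: str) -> dict[str, str]:
        return {
            "collection_name": f"entity_description_embeddings_{context_id}",
            "reports_name": f"reports_{context_id}",
            "text_units_name": f"text_units_{context_id}",
        }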
reports = read_indexer_reports(final_community_reports, final_nodes, community_level) + text_units = read_indexer_text_units(final_text_units) description_embedding_store.load_entities(entities) if self.use_kusto_community_reports: description_embedding_store.load_reports(reports) + description_embedding_store.load_text_units(text_units) if config.graphdb.enabled: graph_db_client.write_vertices(final_entities) diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index 7be75d44a5..3986cd3678 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -214,5 +214,8 @@ def setup_entities(self) -> None: def setup_reports(self) -> None: raise NotImplementedError("Setting up reports is not supported for Azure AI Search") + def setup_text_units(self) -> None: + raise NotImplementedError("setup_text_units(): Unsupported for this vector store.") + def unload_entities(self) -> None: raise NotImplementedError("unload_entities(): Unsupported for this vector store.") \ No newline at end of file diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index a968f56a6d..e7398bb6ad 100644 --- a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -47,6 +47,7 @@ def __init__( collection_name: str, vector_name: str, reports_name: str, + text_units_name: str, db_connection: Any | None = None, document_collection: Any | None = None, query_filter: Any | None = None, @@ -55,6 +56,7 @@ def __init__( self.collection_name = collection_name self.vector_name = vector_name self.reports_name = reports_name + self.text_units_name = text_units_name self.db_connection = db_connection self.document_collection = document_collection self.query_filter = query_filter @@ -114,6 +116,10 @@ def setup_entities(self) -> None: def setup_reports(self) -> None: """Setup the reports in the vector-store.""" + @abstractmethod + def setup_text_units(self) -> None: + """Setup the reports in the vector-store.""" + @abstractmethod def unload_entities(self) -> None: """Remove context from the databases.""" \ No newline at end of file diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 468c5e5428..ae7a54a075 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -260,9 +260,23 @@ def load_reports(self, reports: list[CommunityReport], overwrite: bool = False) self.setup_reports() # Ingest data - ingestion_command = f".ingest inline into table {self.reports_name} <| {df.to_csv(index=False, header=False)}" + ingestion_command = f".ingest inline into table {self.text_units_name} <| {df.to_csv(index=False, header=False)}" self.client.execute(self.database, ingestion_command) + def setup_text_units(self) -> None: + command = f".drop table {self.text_units_name} ifexists" + self.client.execute(self.database, command) + command = f".create table {self.text_units_name} (id: string, text: string, n_tokens: string, document_ids: string, entity_ids: string, relationship_ids: string)" + self.client.execute(self.database, command) + + + def load_text_units(self, units: list[Entity], overwrite: bool = False) -> None: + df = pd.DataFrame(units) + if overwrite: + self.setup_text_units() + + ingestion_command = f".ingest inline into table {self.text_units_name} <| {df.to_csv(index=False, header=False)}" + self.client.execute(self.database, ingestion_command) def get_extracted_reports( self, community_ids: list[int], **kwargs: Any diff --git 
a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index 4bff541703..0599f4f83e 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ -141,5 +141,8 @@ def setup_entities(self) -> None: def setup_reports(self) -> None: raise NotImplementedError("Setting up community reports is not supported for LanceDB") + def setup_text_units(self) -> None: + raise NotImplementedError("setup_text_units(): Unsupported for this vector store.") + def unload_entities(self) -> None: - raise NotImplementedError("unload_entities(): Unsupported for this vector store.") \ No newline at end of file + raise NotImplementedError("unload_entities(): Unsupported for this vector store.") From 9a03893b7952c49fd94ea5d21166d5396eb2c589 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Tue, 17 Sep 2024 12:14:33 -0700 Subject: [PATCH 74/87] Text units 2 --- graphrag/index/context_switch/contextSwitcher.py | 5 ++++- graphrag/vector_stores/azure_ai_search.py | 4 ++++ graphrag/vector_stores/base.py | 5 +++++ graphrag/vector_stores/kusto.py | 5 ++++- graphrag/vector_stores/lancedb.py | 4 ++++ 5 files changed, 21 insertions(+), 2 deletions(-) diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index 9547fbaa57..26c007e428 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -20,6 +20,7 @@ ) from graphrag.config.enums import StorageType from graphrag.model.community_report import CommunityReport +from graphrag.model import TextUnit from graphrag.model.entity import Entity from graphrag.query.indexer_adapters import ( read_indexer_entities, @@ -261,7 +262,9 @@ def _create_graphrag_config( description_embedding_store.load_entities(entities) if self.use_kusto_community_reports: - description_embedding_store.load_reports(reports) + raise ValueError("Community reports not supported for kusto.") + #description_embedding_store.load_reports(reports) + description_embedding_store.load_text_units(text_units) if config.graphdb.enabled: diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index 3986cd3678..9a53c9a5b3 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -27,6 +27,7 @@ from graphrag.model.community_report import CommunityReport from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder +from graphrag.model import TextUnit from .base import ( DEFAULT_VECTOR_SIZE, @@ -205,6 +206,9 @@ def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: raise NotImplementedError("Loading reports is not supported for Azure AI Search") + def load_text_units(self, units: list[TextUnit], overwrite: bool = True) -> None: + raise NotImplementedError("load_text_units(): Unsupported for this vector store.") + def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]: raise NotImplementedError("Extracting reports is not supported for Azure AI Search") diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index e7398bb6ad..f143bc77f5 100644 --- a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -10,6 +10,7 @@ from graphrag.model.community_report import CommunityReport from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder +from 
graphrag.model import TextUnit DEFAULT_VECTOR_SIZE: int = 1536 @@ -120,6 +121,10 @@ def setup_reports(self) -> None: def setup_text_units(self) -> None: """Setup the reports in the vector-store.""" + @abstractmethod + def load_text_units(self, units: list[TextUnit], overwrite: bool = True) -> None: + """Load reports into the vector-store.""" + @abstractmethod def unload_entities(self) -> None: """Remove context from the databases.""" \ No newline at end of file diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index ae7a54a075..b51dfd2cf6 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -9,6 +9,7 @@ from graphrag.model.community_report import CommunityReport from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder +from graphrag.model import TextUnit import pandas as pd from pathlib import Path @@ -218,6 +219,8 @@ def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int def unload_entities(self) -> None: self.client.execute(self.database,f".drop table {self.collection_name} ifexists") + self.client.execute(self.database,f".drop table {self.text_units_name} ifexists") + def setup_entities(self) -> None: command = f".drop table {self.collection_name} ifexists" @@ -270,7 +273,7 @@ def setup_text_units(self) -> None: self.client.execute(self.database, command) - def load_text_units(self, units: list[Entity], overwrite: bool = False) -> None: + def load_text_units(self, units: list[TextUnit], overwrite: bool = False) -> None: df = pd.DataFrame(units) if overwrite: self.setup_text_units() diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index 0599f4f83e..fb6447d407 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ -7,6 +7,7 @@ from graphrag.model.community_report import CommunityReport from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder +from graphrag.model import TextUnit import json from typing import Any @@ -132,6 +133,9 @@ def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: raise NotImplementedError("Loading reports is not supported for LanceDB") + def load_text_units(self, units: list[TextUnit], overwrite: bool = True) -> None: + raise NotImplementedError("load_text_units(): Unsupported for this vector store.") + def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]: raise NotImplementedError("Extracting community reports is not supported for LanceDB") From 951e847523238cf8576566366367c1afa0853708 Mon Sep 17 00:00:00 2001 From: Sirus Sh Date: Tue, 17 Sep 2024 12:34:59 -0700 Subject: [PATCH 75/87] minor fix --- graphrag/query/cli.py | 2 +- graphrag/vector_stores/kusto.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 3a96c6a823..386be299f7 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -256,7 +256,7 @@ def path0( ######################################################################################## if use_kusto_community_reports: - ValueError("Using community reports is not supported.") + raise ValueError("Using community reports is not supported.") description_embedding_store = __get_embedding_description_store( entities=entities, diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index 
b51dfd2cf6..273de5ea96 100644
--- a/graphrag/vector_stores/kusto.py
+++ b/graphrag/vector_stores/kusto.py
@@ -220,7 +220,7 @@ def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int
     def unload_entities(self) -> None:
         self.client.execute(self.database,f".drop table {self.collection_name} ifexists")
         self.client.execute(self.database,f".drop table {self.text_units_name} ifexists")
-
+        self.client.execute(self.database,f".drop table {self.reports_name} ifexists")

     def setup_entities(self) -> None:
         command = f".drop table {self.collection_name} ifexists"
@@ -263,7 +263,7 @@ def load_reports(self, reports: list[CommunityReport], overwrite: bool = False)
             self.setup_reports()

         # Ingest data
-        ingestion_command = f".ingest inline into table {self.text_units_name} <| {df.to_csv(index=False, header=False)}"
+        ingestion_command = f".ingest inline into table {self.reports_name} <| {df.to_csv(index=False, header=False)}"
         self.client.execute(self.database, ingestion_command)

     def setup_text_units(self) -> None:

From 175b07574bbee63219da7a1553db43edd6228445 Mon Sep 17 00:00:00 2001
From: Sirus Sh
Date: Tue, 17 Sep 2024 19:22:01 -0700
Subject: [PATCH 76/87] Minor fix

---
 graphrag/query/cli.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index 386be299f7..a8a06d69cd 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -68,6 +68,7 @@ def __get_embedding_description_store(
     )
     config_args.update({"vector_name": vector_name})
     config_args.update({"reports_name": f"reports_{context_id}" if context_id else "reports"})
+    config_args.update({"text_units_name": f"text_units_{context_id}"})

     description_embedding_store = VectorStoreFactory.get_vector_store(
         vector_store_type=vector_store_type, kwargs=config_args
@@ -169,8 +170,6 @@ def path0(
         data_dir, root_dir, config_dir
     )

-
-
     vector_store_args = (
         config.embeddings.vector_store if config.embeddings.vector_store else {}
     )

From 58d30ded6f7fd6d538d527e74c3269867895f58a Mon Sep 17 00:00:00 2001
From: Liam Dannaher
Date: Thu, 19 Sep 2024 09:32:57 -0700
Subject: [PATCH 77/87] Graphrag using Azure OpenAI uses Managed Identity when no API_KEY present

---
 graphrag/llm/openai/create_openai_client.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/graphrag/llm/openai/create_openai_client.py b/graphrag/llm/openai/create_openai_client.py
index 40d7d649d7..cd149323c6 100644
--- a/graphrag/llm/openai/create_openai_client.py
+++ b/graphrag/llm/openai/create_openai_client.py
@@ -6,7 +6,7 @@
 import logging
 from functools import cache

-from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+from azure.identity import ManagedIdentityCredential, get_bearer_token_provider
 from openai import AsyncAzureOpenAI, AsyncOpenAI

 from .openai_configuration import OpenAIConfiguration
@@ -40,7 +40,7 @@ def create_openai_client(
         return AsyncAzureOpenAI(
             api_key=configuration.api_key if configuration.api_key else None,
             azure_ad_token_provider=get_bearer_token_provider(
-                DefaultAzureCredential(), cognitive_services_endpoint
+                ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint
             )
             if not configuration.api_key
             else None,

From 1f6a49ce25556d71f222e6a57e9591324834de52 Mon Sep 17 00:00:00 2001
From: Liam Dannaher
Date: Thu, 19 Sep 2024 11:54:40 -0700
Subject: [PATCH 78/87] Query & Embedding Managed Identities changes.
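
Both query-side factories below make the same swap as create_openai_client in
PATCH 77: DefaultAzureCredential is replaced by the user-assigned managed
identity hard-coded in these diffs. A rough sketch of the resulting wiring,
assuming the standard Azure OpenAI token scope (the diffs pass their own
cognitive_services_endpoint variable instead):

    from azure.identity import ManagedIdentityCredential, get_bearer_token_provider

    # Assumed scope; the real code passes cognitive_services_endpoint here.
    scope = "https://cognitiveservices.azure.com/.default"

    token_provider = get_bearer_token_provider(
        ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"),
        scope,
    )
    # token_provider() returns a fresh bearer token minted by the managed
    # identity; it is only wired in when no api_key is configured.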
--- graphrag/query/factories.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py index c4320f380a..28caf61bb0 100644 --- a/graphrag/query/factories.py +++ b/graphrag/query/factories.py @@ -5,7 +5,7 @@ from graphrag.config.models.graphdb_config import GraphDBConfig import tiktoken -from azure.identity import DefaultAzureCredential, get_bearer_token_provider +from azure.identity import ManagedIdentityCredential, get_bearer_token_provider from graphrag.config import ( GraphRagConfig, @@ -53,7 +53,7 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI: api_key=config.llm.api_key, azure_ad_token_provider=( get_bearer_token_provider( - DefaultAzureCredential(), cognitive_services_endpoint + ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint ) if is_azure_client and not config.llm.api_key else None @@ -85,7 +85,7 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding: api_key=config.embeddings.llm.api_key, azure_ad_token_provider=( get_bearer_token_provider( - DefaultAzureCredential(), cognitive_services_endpoint + ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint ) if is_azure_client and not config.embeddings.llm.api_key else None From 43806a219c19bc9bbeba4fff45daccec33a0b752 Mon Sep 17 00:00:00 2001 From: Prateek Jain Date: Sun, 22 Sep 2024 22:13:32 -0700 Subject: [PATCH 79/87] Added the func app compatible code --- func-app/.gitignore | 49 ++ func-app/.vscode/launch.json | 13 + func-app/.vscode/settings.json | 8 + func-app/.vscode/tasks.json | 26 + func-app/common/graph_db_client.py | 158 ++++ func-app/function_app.py | 37 + func-app/graphrag/__init__.py | 4 + .../graphrag/common/blob_storage_client.py | 58 ++ func-app/graphrag/common/config/storage.py | 72 ++ func-app/graphrag/common/graph_db_client.py | 1 + func-app/graphrag/common/kusto_db_client.py | 1 + func-app/graphrag/common/progress/__init__.py | 8 + func-app/graphrag/common/progress/rich.py | 165 +++++ func-app/graphrag/common/progress/types.py | 128 ++++ func-app/graphrag/common/storage/__init__.py | 19 + .../common/storage/blob_pipeline_storage.py | 375 ++++++++++ .../common/storage/file_pipeline_storage.py | 166 +++++ .../graphrag/common/storage/load_storage.py | 40 + .../common/storage/memory_pipeline_storage.py | 79 ++ func-app/graphrag/common/storage/typing.py | 80 ++ .../graphrag/common/utils/common_utils.py | 11 + .../graphrag/common/utils/context_utils.py | 9 + func-app/graphrag/config/__init__.py | 128 ++++ .../graphrag/config/create_graphrag_config.py | 687 ++++++++++++++++++ func-app/graphrag/config/defaults.py | 106 +++ func-app/graphrag/config/enums.py | 127 ++++ .../graphrag/config/environment_reader.py | 155 ++++ func-app/graphrag/config/errors.py | 40 + .../graphrag/config/input_models/__init__.py | 50 ++ .../config/input_models/cache_config_input.py | 18 + .../input_models/chunking_config_input.py | 15 + .../claim_extraction_config_input.py | 19 + .../cluster_graph_config_input.py | 13 + .../community_reports_config_input.py | 17 + .../input_models/embed_graph_config_input.py | 18 + .../entity_extraction_config_input.py | 18 + .../global_search_config_input.py | 16 + .../input_models/graphrag_config_input.py | 49 ++ .../config/input_models/input_config_input.py | 27 + .../config/input_models/llm_config_input.py | 18 + .../input_models/llm_parameters_input.py | 31 + .../input_models/local_search_config_input.py | 18 + 
.../parallelization_parameters_input.py | 13 + .../query_context_config_input.py | 7 + .../input_models/reporting_config_input.py | 18 + .../input_models/snapshots_config_input.py | 14 + .../input_models/storage_config_input.py | 18 + .../summarize_descriptions_config_input.py | 16 + .../text_embedding_config_input.py | 23 + .../config/input_models/umap_config_input.py | 12 + func-app/graphrag/config/models/__init__.py | 52 ++ .../graphrag/config/models/cache_config.py | 29 + .../graphrag/config/models/chunking_config.py | 40 + .../config/models/claim_extraction_config.py | 57 ++ .../config/models/cluster_graph_config.py | 28 + .../config/models/community_reports_config.py | 48 ++ .../config/models/embed_graph_config.py | 48 ++ .../config/models/entity_extraction_config.py | 53 ++ .../config/models/global_search_config.py | 45 ++ .../config/models/graph_rag_config.py | 158 ++++ .../graphrag/config/models/graphdb_config.py | 37 + .../graphrag/config/models/input_config.py | 60 ++ func-app/graphrag/config/models/llm_config.py | 27 + .../graphrag/config/models/llm_parameters.py | 87 +++ .../config/models/local_search_config.py | 51 ++ .../models/parallelization_parameters.py | 21 + .../config/models/query_context_config.py | 16 + .../config/models/reporting_config.py | 30 + .../config/models/snapshots_config.py | 25 + .../graphrag/config/models/storage_config.py | 33 + .../models/summarize_descriptions_config.py | 43 ++ .../config/models/text_embedding_config.py | 46 ++ .../graphrag/config/models/umap_config.py | 17 + func-app/graphrag/config/read_dotenv.py | 25 + func-app/graphrag/index/__init__.py | 78 ++ func-app/graphrag/index/__main__.py | 125 ++++ func-app/graphrag/index/bootstrap.py | 28 + func-app/graphrag/index/cache/__init__.py | 18 + .../index/cache/json_pipeline_cache.py | 64 ++ func-app/graphrag/index/cache/load_cache.py | 51 ++ .../index/cache/memory_pipeline_cache.py | 83 +++ .../index/cache/noop_pipeline_cache.py | 65 ++ .../graphrag/index/cache/pipeline_cache.py | 67 ++ func-app/graphrag/index/cli.py | 356 +++++++++ func-app/graphrag/index/config/__init__.py | 69 ++ func-app/graphrag/index/config/cache.py | 82 +++ func-app/graphrag/index/config/input.py | 120 +++ func-app/graphrag/index/config/pipeline.py | 70 ++ func-app/graphrag/index/config/reporting.py | 77 ++ func-app/graphrag/index/config/workflow.py | 34 + func-app/graphrag/index/context.py | 42 ++ .../index/context_switch/contextSwitcher.py | 288 ++++++++ .../graphrag/index/create_pipeline_config.py | 595 +++++++++++++++ func-app/graphrag/index/emit/__init__.py | 21 + .../graphrag/index/emit/csv_table_emitter.py | 33 + func-app/graphrag/index/emit/factories.py | 46 ++ .../graphrag/index/emit/graph_db_emitter.py | 24 + .../graphrag/index/emit/json_table_emitter.py | 34 + .../index/emit/parquet_table_emitter.py | 54 ++ func-app/graphrag/index/emit/table_emitter.py | 15 + func-app/graphrag/index/emit/types.py | 15 + func-app/graphrag/index/errors.py | 25 + func-app/graphrag/index/graph/__init__.py | 4 + .../index/graph/embedding/__init__.py | 8 + .../index/graph/embedding/embedding.py | 41 ++ .../index/graph/extractors/__init__.py | 20 + .../index/graph/extractors/claims/__init__.py | 9 + .../extractors/claims/claim_extractor.py | 248 +++++++ .../index/graph/extractors/claims/prompts.py | 61 ++ .../extractors/community_reports/__init__.py | 35 + .../community_reports/build_mixed_context.py | 69 ++ .../community_reports_extractor.py | 107 +++ .../prep_community_report_context.py | 181 +++++ 
.../extractors/community_reports/prompts.py | 150 ++++ .../extractors/community_reports/schemas.py | 52 ++ .../community_reports/sort_context.py | 156 ++++ .../extractors/community_reports/utils.py | 53 ++ .../index/graph/extractors/graph/__init__.py | 18 + .../graph/extractors/graph/graph_extractor.py | 305 ++++++++ .../index/graph/extractors/graph/prompts.py | 129 ++++ .../graph/extractors/summarize/__init__.py | 12 + .../description_summary_extractor.py | 135 ++++ .../graph/extractors/summarize/prompts.py | 19 + .../graphrag/index/graph/utils/__init__.py | 9 + .../index/graph/utils/normalize_node_names.py | 14 + .../graphrag/index/graph/utils/stable_lcc.py | 60 ++ .../index/graph/visualization/__init__.py | 14 + .../visualization/compute_umap_positions.py | 144 ++++ .../index/graph/visualization/typing.py | 27 + func-app/graphrag/index/init_content.py | 176 +++++ func-app/graphrag/index/input/__init__.py | 8 + func-app/graphrag/index/input/csv.py | 138 ++++ func-app/graphrag/index/input/load_input.py | 84 +++ func-app/graphrag/index/input/text.py | 72 ++ func-app/graphrag/index/llm/__init__.py | 14 + func-app/graphrag/index/llm/load_llm.py | 313 ++++++++ func-app/graphrag/index/llm/types.py | 10 + .../graphrag/index/load_pipeline_config.py | 80 ++ func-app/graphrag/index/py.typed | 2 + func-app/graphrag/index/reporting/__init__.py | 18 + .../reporting/blob_workflow_callbacks.py | 108 +++ .../reporting/console_workflow_callbacks.py | 32 + .../reporting/file_workflow_callbacks.py | 67 ++ .../index/reporting/load_pipeline_reporter.py | 47 ++ .../reporting/progress_workflow_callbacks.py | 54 ++ func-app/graphrag/index/run.py | 471 ++++++++++++ .../graphrag/index/text_splitting/__init__.py | 34 + .../index/text_splitting/check_token_limit.py | 15 + .../index/text_splitting/text_splitting.py | 244 +++++++ func-app/graphrag/index/typing.py | 20 + func-app/graphrag/index/utils/__init__.py | 25 + func-app/graphrag/index/utils/dataframes.py | 61 ++ func-app/graphrag/index/utils/dicts.py | 18 + func-app/graphrag/index/utils/ds_util.py | 32 + func-app/graphrag/index/utils/hashing.py | 14 + func-app/graphrag/index/utils/is_null.py | 19 + func-app/graphrag/index/utils/load_graph.py | 11 + func-app/graphrag/index/utils/rate_limiter.py | 40 + func-app/graphrag/index/utils/string.py | 19 + func-app/graphrag/index/utils/tokens.py | 41 ++ .../graphrag/index/utils/topological_sort.py | 12 + func-app/graphrag/index/utils/uuid.py | 14 + func-app/graphrag/index/verbs/__init__.py | 50 ++ .../index/verbs/covariates/__init__.py | 8 + .../covariates/extract_covariates/__init__.py | 8 + .../extract_covariates/extract_covariates.py | 110 +++ .../extract_covariates/strategies/__init__.py | 4 + .../strategies/graph_intelligence/__init__.py | 8 + .../strategies/graph_intelligence/defaults.py | 21 + .../run_gi_extract_claims.py | 106 +++ .../graphrag/index/verbs/covariates/typing.py | 52 ++ .../graphrag/index/verbs/entities/__init__.py | 9 + .../verbs/entities/extraction/__init__.py | 8 + .../entities/extraction/entity_extract.py | 202 +++++ .../extraction/strategies/__init__.py | 4 + .../strategies/graph_intelligence/__init__.py | 8 + .../strategies/graph_intelligence/defaults.py | 25 + .../run_graph_intelligence.py | 142 ++++ .../entities/extraction/strategies/nltk.py | 61 ++ .../entities/extraction/strategies/typing.py | 44 ++ .../verbs/entities/summarize/__init__.py | 8 + .../summarize/description_summarize.py | 207 ++++++ .../entities/summarize/strategies/__init__.py | 8 + 
.../strategies/graph_intelligence/__init__.py | 8 + .../strategies/graph_intelligence/defaults.py | 17 + .../run_graph_intelligence.py | 70 ++ .../entities/summarize/strategies/typing.py | 34 + func-app/graphrag/index/verbs/genid.py | 66 ++ .../graphrag/index/verbs/graph/__init__.py | 36 + .../index/verbs/graph/clustering/__init__.py | 8 + .../verbs/graph/clustering/cluster_graph.py | 182 +++++ .../graph/clustering/strategies/__init__.py | 4 + .../graph/clustering/strategies/leiden.py | 69 ++ .../index/verbs/graph/clustering/typing.py | 6 + .../graph/compute_edge_combined_degree.py | 70 ++ func-app/graphrag/index/verbs/graph/create.py | 135 ++++ .../index/verbs/graph/embed/__init__.py | 8 + .../index/verbs/graph/embed/embed_graph.py | 98 +++ .../verbs/graph/embed/strategies/__init__.py | 4 + .../graph/embed/strategies/node_2_vec.py | 34 + .../index/verbs/graph/embed/typing.py | 12 + .../index/verbs/graph/layout/__init__.py | 8 + .../index/verbs/graph/layout/layout_graph.py | 139 ++++ .../verbs/graph/layout/methods/__init__.py | 4 + .../index/verbs/graph/layout/methods/umap.py | 82 +++ .../index/verbs/graph/layout/methods/zero.py | 63 ++ .../index/verbs/graph/merge/__init__.py | 8 + .../index/verbs/graph/merge/defaults.py | 21 + .../index/verbs/graph/merge/merge_graphs.py | 217 ++++++ .../index/verbs/graph/merge/typing.py | 49 ++ .../index/verbs/graph/report/__init__.py | 25 + .../graph/report/create_community_reports.py | 131 ++++ .../graph/report/prepare_community_reports.py | 187 +++++ .../prepare_community_reports_claims.py | 50 ++ .../report/prepare_community_reports_edges.py | 48 ++ .../report/prepare_community_reports_nodes.py | 46 ++ .../report/restore_community_hierarchy.py | 78 ++ .../verbs/graph/report/strategies/__init__.py | 4 + .../strategies/graph_intelligence/__init__.py | 8 + .../strategies/graph_intelligence/defaults.py | 26 + .../run_graph_intelligence.py | 99 +++ .../verbs/graph/report/strategies/typing.py | 52 ++ func-app/graphrag/index/verbs/graph/unpack.py | 107 +++ .../index/verbs/overrides/__init__.py | 10 + .../index/verbs/overrides/aggregate.py | 90 +++ .../graphrag/index/verbs/overrides/concat.py | 27 + .../graphrag/index/verbs/overrides/merge.py | 78 ++ func-app/graphrag/index/verbs/snapshot.py | 30 + .../graphrag/index/verbs/snapshot_rows.py | 86 +++ func-app/graphrag/index/verbs/spread_json.py | 55 ++ .../graphrag/index/verbs/text/__init__.py | 18 + .../index/verbs/text/chunk/__init__.py | 8 + .../verbs/text/chunk/strategies/__init__.py | 4 + .../verbs/text/chunk/strategies/sentence.py | 26 + .../verbs/text/chunk/strategies/tokens.py | 81 +++ .../verbs/text/chunk/strategies/typing.py | 17 + .../index/verbs/text/chunk/text_chunk.py | 162 +++++ .../graphrag/index/verbs/text/chunk/typing.py | 19 + .../index/verbs/text/embed/__init__.py | 8 + .../verbs/text/embed/strategies/__init__.py | 4 + .../index/verbs/text/embed/strategies/mock.py | 34 + .../verbs/text/embed/strategies/openai.py | 181 +++++ .../verbs/text/embed/strategies/typing.py | 29 + .../index/verbs/text/embed/text_embed.py | 269 +++++++ .../index/verbs/text/replace/__init__.py | 8 + .../index/verbs/text/replace/replace.py | 47 ++ .../index/verbs/text/replace/typing.py | 14 + func-app/graphrag/index/verbs/text/split.py | 54 ++ .../index/verbs/text/translate/__init__.py | 8 + .../text/translate/strategies/__init__.py | 9 + .../text/translate/strategies/defaults.py | 8 + .../verbs/text/translate/strategies/mock.py | 28 + .../verbs/text/translate/strategies/openai.py | 93 +++ 
.../verbs/text/translate/strategies/typing.py | 25 + .../verbs/text/translate/text_translate.py | 120 +++ func-app/graphrag/index/verbs/unzip.py | 25 + func-app/graphrag/index/verbs/zip.py | 51 ++ func-app/graphrag/index/workflows/__init__.py | 25 + .../index/workflows/default_workflows.py | 121 +++ func-app/graphrag/index/workflows/load.py | 171 +++++ func-app/graphrag/index/workflows/typing.py | 33 + .../graphrag/index/workflows/v1/__init__.py | 4 + .../workflows/v1/create_base_documents.py | 105 +++ .../workflows/v1/create_base_entity_graph.py | 91 +++ .../v1/create_base_extracted_entities.py | 95 +++ .../workflows/v1/create_base_text_units.py | 112 +++ .../workflows/v1/create_final_communities.py | 172 +++++ .../v1/create_final_community_reports.py | 133 ++++ .../workflows/v1/create_final_covariates.py | 90 +++ .../workflows/v1/create_final_documents.py | 41 ++ .../workflows/v1/create_final_entities.py | 133 ++++ .../index/workflows/v1/create_final_nodes.py | 116 +++ .../v1/create_final_relationships.py | 94 +++ .../workflows/v1/create_final_text_units.py | 161 ++++ .../v1/create_summarized_entities.py | 47 ++ .../v1/join_text_units_to_covariate_ids.py | 44 ++ .../v1/join_text_units_to_entity_ids.py | 50 ++ .../v1/join_text_units_to_relationship_ids.py | 55 ++ func-app/graphrag/llm/__init__.py | 91 +++ func-app/graphrag/llm/base/__init__.py | 10 + .../graphrag/llm/base/_create_cache_key.py | 43 ++ func-app/graphrag/llm/base/base_llm.py | 65 ++ func-app/graphrag/llm/base/caching_llm.py | 109 +++ .../graphrag/llm/base/rate_limiting_llm.py | 208 ++++++ func-app/graphrag/llm/errors.py | 12 + func-app/graphrag/llm/limiting/__init__.py | 18 + .../llm/limiting/composite_limiter.py | 26 + .../graphrag/llm/limiting/create_limiters.py | 29 + func-app/graphrag/llm/limiting/llm_limiter.py | 19 + .../graphrag/llm/limiting/noop_llm_limiter.py | 19 + .../graphrag/llm/limiting/tpm_rpm_limiter.py | 34 + func-app/graphrag/llm/mock/__init__.py | 12 + func-app/graphrag/llm/mock/mock_chat_llm.py | 52 ++ .../graphrag/llm/mock/mock_completion_llm.py | 37 + func-app/graphrag/llm/openai/__init__.py | 28 + func-app/graphrag/llm/openai/_prompts.py | 39 + .../llm/openai/create_openai_client.py | 65 ++ func-app/graphrag/llm/openai/factories.py | 140 ++++ .../graphrag/llm/openai/json_parsing_llm.py | 38 + .../graphrag/llm/openai/openai_chat_llm.py | 148 ++++ .../llm/openai/openai_completion_llm.py | 43 ++ .../llm/openai/openai_configuration.py | 288 ++++++++ .../llm/openai/openai_embeddings_llm.py | 40 + .../llm/openai/openai_history_tracking_llm.py | 42 ++ .../llm/openai/openai_token_replacing_llm.py | 37 + func-app/graphrag/llm/openai/types.py | 11 + func-app/graphrag/llm/openai/utils.py | 160 ++++ func-app/graphrag/llm/types/__init__.py | 46 ++ func-app/graphrag/llm/types/llm.py | 28 + func-app/graphrag/llm/types/llm_cache.py | 22 + func-app/graphrag/llm/types/llm_callbacks.py | 20 + func-app/graphrag/llm/types/llm_config.py | 35 + .../llm/types/llm_invocation_result.py | 35 + func-app/graphrag/llm/types/llm_io.py | 50 ++ func-app/graphrag/llm/types/llm_types.py | 16 + func-app/graphrag/model/__init__.py | 31 + func-app/graphrag/model/community.py | 54 ++ func-app/graphrag/model/community_report.py | 64 ++ func-app/graphrag/model/covariate.py | 61 ++ func-app/graphrag/model/document.py | 64 ++ func-app/graphrag/model/entity.py | 79 ++ func-app/graphrag/model/identified.py | 17 + func-app/graphrag/model/named.py | 16 + func-app/graphrag/model/relationship.py | 65 ++ func-app/graphrag/model/text_unit.py | 67 ++ 
func-app/graphrag/model/types.py | 8 + func-app/graphrag/prompt_tune/__init__.py | 4 + func-app/graphrag/prompt_tune/__main__.py | 148 ++++ func-app/graphrag/prompt_tune/cli.py | 272 +++++++ .../prompt_tune/generator/__init__.py | 30 + .../generator/community_report_rating.py | 35 + .../community_report_summarization.py | 48 ++ .../generator/community_reporter_role.py | 35 + .../prompt_tune/generator/defaults.py | 10 + .../graphrag/prompt_tune/generator/domain.py | 27 + .../generator/entity_extraction_prompt.py | 107 +++ .../generator/entity_relationship.py | 65 ++ .../generator/entity_summarization_prompt.py | 36 + .../prompt_tune/generator/entity_types.py | 45 ++ .../prompt_tune/generator/language.py | 27 + .../graphrag/prompt_tune/generator/persona.py | 27 + .../graphrag/prompt_tune/loader/__init__.py | 14 + .../graphrag/prompt_tune/loader/config.py | 43 ++ func-app/graphrag/prompt_tune/loader/input.py | 110 +++ .../graphrag/prompt_tune/prompt/__init__.py | 32 + .../prompt/community_report_rating.py | 132 ++++ .../prompt/community_reporter_role.py | 20 + .../graphrag/prompt_tune/prompt/domain.py | 12 + .../prompt_tune/prompt/entity_relationship.py | 355 +++++++++ .../prompt_tune/prompt/entity_types.py | 89 +++ .../graphrag/prompt_tune/prompt/language.py | 12 + .../graphrag/prompt_tune/prompt/persona.py | 13 + .../graphrag/prompt_tune/template/__init__.py | 24 + .../community_report_summarization.py | 95 +++ .../prompt_tune/template/entity_extraction.py | 141 ++++ .../template/entity_summarization.py | 22 + func-app/graphrag/query/__init__.py | 4 + func-app/graphrag/query/__main__.py | 133 ++++ func-app/graphrag/query/cli.py | 472 ++++++++++++ .../query/context_builder/__init__.py | 4 + .../query/context_builder/builders.py | 35 + .../context_builder/community_context.py | 253 +++++++ .../context_builder/conversation_history.py | 212 ++++++ .../context_builder/entity_extraction.py | 187 +++++ .../query/context_builder/local_context.py | 360 +++++++++ .../query/context_builder/source_context.py | 110 +++ func-app/graphrag/query/factories.py | 211 ++++++ func-app/graphrag/query/indexer_adapters.py | 159 ++++ func-app/graphrag/query/input/__init__.py | 4 + .../graphrag/query/input/loaders/__init__.py | 4 + func-app/graphrag/query/input/loaders/dfs.py | 340 +++++++++ .../graphrag/query/input/loaders/utils.py | 245 +++++++ .../query/input/retrieval/__init__.py | 4 + .../input/retrieval/community_reports.py | 74 ++ .../query/input/retrieval/covariates.py | 52 ++ .../query/input/retrieval/entities.py | 93 +++ .../query/input/retrieval/relationships.py | 217 ++++++ .../query/input/retrieval/text_units.py | 52 ++ func-app/graphrag/query/llm/__init__.py | 4 + func-app/graphrag/query/llm/base.py | 54 ++ func-app/graphrag/query/llm/oai/__init__.py | 21 + func-app/graphrag/query/llm/oai/base.py | 187 +++++ .../graphrag/query/llm/oai/chat_openai.py | 206 ++++++ func-app/graphrag/query/llm/oai/embedding.py | 182 +++++ func-app/graphrag/query/llm/oai/openai.py | 187 +++++ func-app/graphrag/query/llm/oai/typing.py | 23 + func-app/graphrag/query/llm/text_utils.py | 42 ++ func-app/graphrag/query/progress.py | 43 ++ .../graphrag/query/question_gen/__init__.py | 4 + func-app/graphrag/query/question_gen/base.py | 65 ++ .../graphrag/query/question_gen/local_gen.py | 194 +++++ .../query/question_gen/system_prompt.py | 28 + .../query/structured_search/__init__.py | 4 + .../graphrag/query/structured_search/base.py | 69 ++ .../global_search/__init__.py | 4 + .../global_search/callbacks.py | 24 + 
.../global_search/community_context.py | 99 +++ .../global_search/map_system_prompt.py | 82 +++ .../global_search/reduce_system_prompt.py | 88 +++ .../structured_search/global_search/search.py | 359 +++++++++ .../local_search/__init__.py | 4 + .../local_search/mixed_context.py | 533 ++++++++++++++ .../structured_search/local_search/search.py | 199 +++++ .../local_search/system_prompt.py | 69 ++ func-app/graphrag/vector_stores/__init__.py | 19 + .../graphrag/vector_stores/azure_ai_search.py | 225 ++++++ func-app/graphrag/vector_stores/base.py | 130 ++++ func-app/graphrag/vector_stores/kusto.py | 308 ++++++++ func-app/graphrag/vector_stores/lancedb.py | 152 ++++ func-app/graphrag/vector_stores/typing.py | 48 ++ func-app/host.json | 15 + func-app/prompts/claim_extraction.txt | 52 ++ func-app/prompts/community_report.txt | 146 ++++ func-app/prompts/entity_extraction.txt | 99 +++ func-app/prompts/summarize_descriptions.txt | 13 + func-app/requirements.txt | 142 ++++ func-app/settings.yaml | 165 +++++ func-app/settings/settings.yaml | 152 ++++ 418 files changed, 30607 insertions(+) create mode 100644 func-app/.gitignore create mode 100644 func-app/.vscode/launch.json create mode 100644 func-app/.vscode/settings.json create mode 100644 func-app/.vscode/tasks.json create mode 100644 func-app/common/graph_db_client.py create mode 100644 func-app/function_app.py create mode 100644 func-app/graphrag/__init__.py create mode 100644 func-app/graphrag/common/blob_storage_client.py create mode 100644 func-app/graphrag/common/config/storage.py create mode 100644 func-app/graphrag/common/graph_db_client.py create mode 100644 func-app/graphrag/common/kusto_db_client.py create mode 100644 func-app/graphrag/common/progress/__init__.py create mode 100644 func-app/graphrag/common/progress/rich.py create mode 100644 func-app/graphrag/common/progress/types.py create mode 100644 func-app/graphrag/common/storage/__init__.py create mode 100644 func-app/graphrag/common/storage/blob_pipeline_storage.py create mode 100644 func-app/graphrag/common/storage/file_pipeline_storage.py create mode 100644 func-app/graphrag/common/storage/load_storage.py create mode 100644 func-app/graphrag/common/storage/memory_pipeline_storage.py create mode 100644 func-app/graphrag/common/storage/typing.py create mode 100644 func-app/graphrag/common/utils/common_utils.py create mode 100644 func-app/graphrag/common/utils/context_utils.py create mode 100644 func-app/graphrag/config/__init__.py create mode 100644 func-app/graphrag/config/create_graphrag_config.py create mode 100644 func-app/graphrag/config/defaults.py create mode 100644 func-app/graphrag/config/enums.py create mode 100644 func-app/graphrag/config/environment_reader.py create mode 100644 func-app/graphrag/config/errors.py create mode 100644 func-app/graphrag/config/input_models/__init__.py create mode 100644 func-app/graphrag/config/input_models/cache_config_input.py create mode 100644 func-app/graphrag/config/input_models/chunking_config_input.py create mode 100644 func-app/graphrag/config/input_models/claim_extraction_config_input.py create mode 100644 func-app/graphrag/config/input_models/cluster_graph_config_input.py create mode 100644 func-app/graphrag/config/input_models/community_reports_config_input.py create mode 100644 func-app/graphrag/config/input_models/embed_graph_config_input.py create mode 100644 func-app/graphrag/config/input_models/entity_extraction_config_input.py create mode 100644 func-app/graphrag/config/input_models/global_search_config_input.py create mode 
100644 func-app/graphrag/config/input_models/graphrag_config_input.py
diff --git a/func-app/.gitignore b/func-app/.gitignore
new file mode 100644
index 0000000000..c9d14718e1
--- /dev/null
+++ b/func-app/.gitignore
@@ -0,0 +1,49 @@
+bin
+obj
+csx
+.vs
+edge
+Publish
+
+*.user
+*.suo
+*.cscfg
+*.Cache
+project.lock.json
+
+/packages
+/TestResults
+
+/tools/NuGet.exe
+/App_Data
+/secrets
+/data
+.secrets
+appsettings.json
+local.settings.json
+
+node_modules
+dist
+
+# Local python packages
+.python_packages/
+
+# Python Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Azurite artifacts
+__blobstorage__
+__queuestorage__
+__azurite_db*__.json
+
diff --git a/func-app/.vscode/launch.json b/func-app/.vscode/launch.json
new file mode 100644
index 0000000000..a90b7259e1
--- /dev/null
+++ b/func-app/.vscode/launch.json
@@ -0,0 +1,13 @@
+{
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Attach to Python Functions",
+            "type": "python",
+            "request": "attach",
+            "port": 7071,
+            "preLaunchTask": "func: host start",
+            "justMyCode": true
+        }
+    ]
+}
\ No newline at end of file
diff --git a/func-app/.vscode/settings.json b/func-app/.vscode/settings.json
new file mode 100644
index 0000000000..80db5bb047
--- /dev/null
+++ b/func-app/.vscode/settings.json
@@ -0,0 +1,8 @@
+{
+    "azureFunctions.deploySubpath": ".",
+    "azureFunctions.scmDoBuildDuringDeployment": true,
+    "azureFunctions.pythonVenv": ".venv",
+    "azureFunctions.projectLanguage": "Python",
+    "azureFunctions.projectRuntime": "~4",
+    "debug.internalConsoleOptions": "neverOpen",
+}
\ No newline at end of file
diff --git a/func-app/.vscode/tasks.json b/func-app/.vscode/tasks.json
new file mode 100644
index 0000000000..808884468c
--- /dev/null
+++ b/func-app/.vscode/tasks.json
@@ -0,0 +1,26 @@
+{
+    "version": "2.0.0",
+    "tasks": [
+        {
+            "type": "func",
+            "command": "host start",
+            "problemMatcher": "$func-watch",
+            "isBackground": true,
+            "dependsOn": "pipInstall"
+        },
+        {
+            "label": "pipInstall",
+            "type": "shell",
+            "osx": {
+                "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt"
+            },
+            "windows": {
+                "command": "${config:azureFunctions.pythonVenv}\\Scripts\\python -m pip install -r requirements.txt"
+            },
+            "linux": {
+                "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt"
+            },
+            "problemMatcher": []
+        }
+    ]
+}
\ No newline at end of file
diff --git a/func-app/common/graph_db_client.py b/func-app/common/graph_db_client.py
new file mode 100644
index 0000000000..ab4503f01c
--- /dev/null
+++ b/func-app/common/graph_db_client.py
@@ -0,0 +1,157 @@
+import os
+import pandas as pd
+
+from graphrag.config.models.graphdb_config import GraphDBConfig
+import numpy as np
+
+import ast
+
+from gremlin_python.driver import client, serializer
+from azure.identity import ManagedIdentityCredential
+
+import time
+import json
+
+# Azure Cosmos DB Gremlin Endpoint and other constants
+COSMOS_DB_SCOPE = "https://cosmos.azure.com/.default"  # The scope for Cosmos DB
+class GraphDBClient:
+    def __init__(self,graph_db_params: GraphDBConfig|None,context_id: str|None):
+        self.username_prefix=graph_db_params.username
+        token = f"{graph_db_params.account_key}"
+        #if(os.environ.get("ENVIRONMENT") == "AZURE"):
+        #    credential = ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3")
+        #    token = credential.get_token(COSMOS_DB_SCOPE)
+        self._client=client.Client(
+            url=f"{graph_db_params.gremlin_url}",
+            traversal_source="g",
+            username=self.username_prefix+"-contextid-"+context_id,
+            password=token,
+            message_serializer=serializer.GraphSONSerializersV2d0(),
+        )
+
+    def result_to_df(self,result) -> pd.DataFrame:
+        json_data = []
+        for row in result:
+            json_row = row[0]
+            properties_dict = json_row.pop('properties')
+            formatted_properties={}
+            for k,v in properties_dict.items():
+                new_val=v
+                if isinstance(v,list) and isinstance(v[0],dict):
+                    new_val=v[0]['value']
+                if k=='description_embedding' or k=='text_unit_ids' or k=='graph_embedding':
+                    new_val=ast.literal_eval(new_val)
+                if isinstance(new_val,list):
+                    new_val=np.array(new_val)
+                formatted_properties[k]=new_val
+            json_row.update(formatted_properties)
+            json_data.append(json_row)
+        df = pd.DataFrame(json_data)
+        return df
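For readers unfamiliar with the Gremlin driver's result shape, here is a minimal standalone sketch of the unwrapping that `result_to_df` above performs. The sample row is fabricated for illustration; real rows come back from `client.submit()`.

```python
import ast

import numpy as np
import pandas as pd

# A Gremlin valueMap-style row: each property arrives as a list of
# {'id': ..., 'value': ...} entries, and embeddings arrive as strings.
row = {
    "id": "entity-1",
    "label": "entity",
    "properties": {
        "name": [{"id": "p1", "value": "Alice"}],
        "description_embedding": [{"id": "p2", "value": "[0.1, 0.2, 0.3]"}],
    },
}

properties = row.pop("properties")
formatted = {}
for key, value in properties.items():
    new_val = value
    if isinstance(value, list) and isinstance(value[0], dict):
        new_val = value[0]["value"]          # unwrap the single-valued property
    if key in ("description_embedding", "text_unit_ids", "graph_embedding"):
        new_val = ast.literal_eval(new_val)  # "[0.1, 0.2, 0.3]" -> [0.1, 0.2, 0.3]
    if isinstance(new_val, list):
        new_val = np.array(new_val)          # embeddings become numpy arrays
    formatted[key] = new_val
row.update(formatted)

df = pd.DataFrame([row])
print(df.loc[0, "description_embedding"])    # array([0.1, 0.2, 0.3])
```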
+
+    def remove_graph(self):
+        self._client.submit(message=("g.V().drop()"))
+
+    def query_vertices(self,context_id:str) -> pd.DataFrame:
+        result = self._client.submit(
+            message=(
+                "g.V()"
+            ),
+        )
+        return self.result_to_df(result)
+
+    def query_edges(self,context_id:str) -> pd.DataFrame:
+        result = self._client.submit(
+            message=(
+                "g.E()"
+            ),
+        )
+        return self.result_to_df(result)
+
+    def element_exists(self,element_type:str,element_id:int,conditions:str="")->bool:
+        result=self._client.submit(
+            message=(
+                element_type+
+                ".has('id',prop_id)"+
+                conditions+
+                ".count()"
+            ),
+            bindings={
+                "prop_id":element_id,
+            }
+        )
+        element_count=0
+        for counts in result:
+            element_count=counts[0]
+        return element_count>0
+
+    def write_vertices(self,data: pd.DataFrame)->None:
+        for row in data.itertuples():
+            if self.element_exists("g.V()",row.id):
+                continue
+            else:
+                self._client.submit(
+                    message=(
+                        "g.addV('entity')"
+                        ".property('id', prop_id)"
+                        ".property('name', prop_name)"
+                        ".property('type', prop_type)"
+                        ".property('description', prop_description)"
+                        ".property('human_readable_id', prop_human_readable_id)"
+                        ".property('category', prop_partition_key)"
+                        ".property(list,'description_embedding',prop_description_embedding)"
+                        ".property(list,'graph_embedding',prop_graph_embedding)"
+                        ".property(list,'text_unit_ids',prop_text_unit_ids)"
+                    ),
+                    bindings={
+                        "prop_id": row.id,
+                        "prop_name": row.name,
+                        "prop_type": row.type,
+                        "prop_description": row.description,
+                        "prop_human_readable_id": row.human_readable_id,
+                        "prop_partition_key": "entities",
+                        "prop_description_embedding":json.dumps(row.description_embedding.tolist() if row.description_embedding is not None else []),
+                        "prop_graph_embedding":json.dumps(row.graph_embedding.tolist() if row.graph_embedding is not None else []),
+                        "prop_text_unit_ids":json.dumps(row.text_unit_ids.tolist() if row.text_unit_ids is not None else []),
+                    },
+                )
+                time.sleep(5)
+
+
+    def write_edges(self,data: pd.DataFrame)->None:
+        for row in data.itertuples():
+            if self.element_exists("g.E()",row.id):
+                continue
+            self._client.submit(
+                message=(
+                    "g.V().has('name',prop_source_id)"
+                    ".addE('connects')"
+                    ".to(g.V().has('name',prop_target_id))"
+                    ".property('weight',prop_weight)"
+                    ".property(list,'text_unit_ids',prop_text_unit_ids)"
+                    ".property('description',prop_description)"
+                    ".property('id',prop_id)"
+                    ".property('human_readable_id',prop_human_readable_id)"
+                    ".property('source_degree',prop_source_degree)"
+                    ".property('target_degree',prop_target_degree)"
+                    ".property('rank',prop_rank)"
+                    ".property('source',prop_source)"
+                    ".property('target',prop_target)"
+                ),
+                bindings={
+                    "prop_partition_key": "entities",
+                    "prop_source_id": row.source,
+                    "prop_target_id": row.target,
+                    "prop_weight": row.weight,
+                    "prop_text_unit_ids":json.dumps(row.text_unit_ids.tolist() if row.text_unit_ids is not None else []),
+                    "prop_description": row.description,
+                    "prop_id": row.id,
+                    "prop_human_readable_id": row.human_readable_id,
+                    "prop_source_degree": row.source_degree,
+                    "prop_target_degree": row.target_degree,
+                    "prop_rank": row.rank,
+                    "prop_source": row.source,
+                    "prop_target": row.target,
+                },
+            )
+            time.sleep(5)
\ No newline at end of file
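A minimal usage sketch for `GraphDBClient`, assuming the `GraphDBConfig` model (added elsewhere in this change, content not shown here) accepts `account_key`, `username`, and `gremlin_url` keyword arguments; the endpoint, key, and context id below are placeholders.

```python
from common.graph_db_client import GraphDBClient
from graphrag.config.models.graphdb_config import GraphDBConfig

# Hypothetical settings; a real deployment reads these from settings.yaml.
params = GraphDBConfig(
    account_key="<cosmos-account-key>",
    username="/dbs/graphrag/colls/entities",
    gremlin_url="wss://<account>.gremlin.cosmos.azure.com:443/",
)

client = GraphDBClient(params, context_id="ctx1")

# Both query helpers return pandas DataFrames via result_to_df.
vertices = client.query_vertices(context_id="ctx1")
edges = client.query_edges(context_id="ctx1")
print(vertices.columns.tolist())
```

Note that the traversals are sent with parameter bindings (`prop_*`) rather than string interpolation, which is the recommended pattern for Gremlin on Cosmos DB.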
+ "prop_target_degree": row.target_degree, + "prop_rank": row.rank, + "prop_source": row.source, + "prop_target": row.target, + }, + ) + time.sleep(5) \ No newline at end of file diff --git a/func-app/function_app.py b/func-app/function_app.py new file mode 100644 index 0000000000..38df93461f --- /dev/null +++ b/func-app/function_app.py @@ -0,0 +1,37 @@ +import azure.functions as func +import datetime +import json +import logging +import csv +import codecs +from graphrag.index.cli import index_cli + +app = func.FunctionApp() + +@app.function_name('IndexingPipelineFunc') +@app.route(route="index", auth_level=func.AuthLevel.ANONYMOUS) +def indexing(req: func.HttpRequest) -> func.HttpResponse: + logging.info('Python HTTP trigger function processed a request.') + + index_cli( + root = "", + verbose=False, + resume=False, + memprofile=False, + nocache=False, + config=None, + emit=None, + dryrun=False, + init=True, + overlay_defaults=False, + cli=True, + context_id=None, + context_operation=None, + community_level=None, + use_kusto_community_reports=None, + optimized_search=None + ) + return func.HttpResponse( + "Wow this first HTTP Function works!!!!", + status_code=200 + ) diff --git a/func-app/graphrag/__init__.py b/func-app/graphrag/__init__.py new file mode 100644 index 0000000000..a1e9b589bf --- /dev/null +++ b/func-app/graphrag/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The GraphRAG package.""" diff --git a/func-app/graphrag/common/blob_storage_client.py b/func-app/graphrag/common/blob_storage_client.py new file mode 100644 index 0000000000..78dc809579 --- /dev/null +++ b/func-app/graphrag/common/blob_storage_client.py @@ -0,0 +1,58 @@ +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient + + +class BlobStorageClient: + """The Blob-Storage implementation.""" + + _connection_string: str | None + _container_name: str + _path_prefix: str + _encoding: str + _storage_account_blob_url: str | None + + def __init__( + self, + connection_string: str | None, + container_name: str, + encoding: str | None = None, + path_prefix: str | None = None, + storage_account_blob_url: str | None = None, + ): + """Create a new BlobStorage instance.""" + if connection_string: + self._blob_service_client = BlobServiceClient.from_connection_string( + connection_string + ) + else: + if storage_account_blob_url is None: + msg = "Either connection_string or storage_account_blob_url must be provided." 
diff --git a/func-app/graphrag/__init__.py b/func-app/graphrag/__init__.py
new file mode 100644
index 0000000000..a1e9b589bf
--- /dev/null
+++ b/func-app/graphrag/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The GraphRAG package."""
diff --git a/func-app/graphrag/common/blob_storage_client.py b/func-app/graphrag/common/blob_storage_client.py
new file mode 100644
index 0000000000..78dc809579
--- /dev/null
+++ b/func-app/graphrag/common/blob_storage_client.py
@@ -0,0 +1,58 @@
+from azure.identity import DefaultAzureCredential
+from azure.storage.blob import BlobServiceClient
+
+
+class BlobStorageClient:
+    """The Blob-Storage implementation."""
+
+    _connection_string: str | None
+    _container_name: str
+    _path_prefix: str
+    _encoding: str
+    _storage_account_blob_url: str | None
+
+    def __init__(
+        self,
+        connection_string: str | None,
+        container_name: str,
+        encoding: str | None = None,
+        path_prefix: str | None = None,
+        storage_account_blob_url: str | None = None,
+    ):
+        """Create a new BlobStorage instance."""
+        if connection_string:
+            self._blob_service_client = BlobServiceClient.from_connection_string(
+                connection_string
+            )
+        else:
+            if storage_account_blob_url is None:
+                msg = "Either connection_string or storage_account_blob_url must be provided."
+                raise ValueError(msg)
+
+            self._blob_service_client = BlobServiceClient(
+                account_url=storage_account_blob_url,
+                credential=DefaultAzureCredential(),
+            )
+        self._encoding = encoding or "utf-8"
+        self._container_name = container_name
+        self._connection_string = connection_string
+        self._path_prefix = path_prefix or ""
+        self._storage_account_blob_url = storage_account_blob_url
+        self._storage_account_name = (
+            storage_account_blob_url.split("//")[1].split(".")[0]
+            if storage_account_blob_url
+            else None
+        )
+        #log.info(
+        #    "creating blob storage at container=%s, path=%s",
+        #    self._container_name,
+        #    self._path_prefix,
+        #)
+
+    def get_blob_service_client(self):
+        """Get the BlobServiceClient instance."""
+        return self._blob_service_client
+
+    def get_container_client(self):
+        """Get the container client instance."""
+        return self._blob_service_client.get_container_client(self._container_name)
diff --git a/func-app/graphrag/common/config/storage.py b/func-app/graphrag/common/config/storage.py
new file mode 100644
index 0000000000..023d50e249
--- /dev/null
+++ b/func-app/graphrag/common/config/storage.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing 'PipelineStorageConfig', 'PipelineFileStorageConfig' and 'PipelineMemoryStorageConfig' models."""
+
+from __future__ import annotations
+
+from typing import Generic, Literal, TypeVar
+
+from pydantic import BaseModel
+from pydantic import Field as pydantic_Field
+
+from graphrag.config.enums import StorageType
+
+T = TypeVar("T")
+
+
+class PipelineStorageConfig(BaseModel, Generic[T]):
+    """Represent the storage configuration for the pipeline."""
+
+    type: T
+
+
+class PipelineFileStorageConfig(PipelineStorageConfig[Literal[StorageType.file]]):
+    """Represent the file storage configuration for the pipeline."""
+
+    type: Literal[StorageType.file] = StorageType.file
+    """The type of storage."""
+
+    base_dir: str | None = pydantic_Field(
+        description="The base directory for the storage.", default=None
+    )
+    """The base directory for the storage."""
+
+
+class PipelineMemoryStorageConfig(PipelineStorageConfig[Literal[StorageType.memory]]):
+    """Represent the memory storage configuration for the pipeline."""
+
+    type: Literal[StorageType.memory] = StorageType.memory
+    """The type of storage."""
+
+
+class PipelineBlobStorageConfig(PipelineStorageConfig[Literal[StorageType.blob]]):
+    """Represents the blob storage configuration for the pipeline."""
+
+    type: Literal[StorageType.blob] = StorageType.blob
+    """The type of storage."""
+
+    connection_string: str | None = pydantic_Field(
+        description="The blob storage connection string for the storage.", default=None
+    )
+    """The blob storage connection string for the storage."""
+
+    container_name: str = pydantic_Field(
+        description="The container name for storage", default=None
+    )
+    """The container name for storage."""
+
+    base_dir: str | None = pydantic_Field(
+        description="The base directory for the storage.", default=None
+    )
+    """The base directory for the storage."""
+
+    storage_account_blob_url: str | None = pydantic_Field(
+        description="The storage account blob url.", default=None
+    )
+    """The storage account blob url."""
+
+
+PipelineStorageConfigTypes = (
+    PipelineFileStorageConfig | PipelineMemoryStorageConfig | PipelineBlobStorageConfig
+)
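A short sketch of how these discriminated config models are meant to be constructed; the values are illustrative only, and the import path assumes `func-app` is on `sys.path`.

```python
from graphrag.common.config.storage import (
    PipelineBlobStorageConfig,
    PipelineFileStorageConfig,
)

# The `type` field defaults per subclass, so callers only supply the
# backend-specific settings; a union-typed field on a parent model can
# then discriminate on `type` at validation time.
file_cfg = PipelineFileStorageConfig(base_dir="./output")
blob_cfg = PipelineBlobStorageConfig(
    container_name="artifacts",
    storage_account_blob_url="https://<account>.blob.core.windows.net",
)

print(file_cfg.type)  # StorageType.file
print(blob_cfg.type)  # StorageType.blob
```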
diff --git a/func-app/graphrag/common/graph_db_client.py b/func-app/graphrag/common/graph_db_client.py
new file mode 100644
index 0000000000..35b70dd385
--- /dev/null
+++ b/func-app/graphrag/common/graph_db_client.py
@@ -0,0 +1 @@
+# create Gremlin and cosmos db clients by reading settings from settings.yaml
\ No newline at end of file
diff --git a/func-app/graphrag/common/kusto_db_client.py b/func-app/graphrag/common/kusto_db_client.py
new file mode 100644
index 0000000000..413a47341e
--- /dev/null
+++ b/func-app/graphrag/common/kusto_db_client.py
@@ -0,0 +1 @@
+# create Gremlin and kusto db clients by reading settings from settings.yaml
\ No newline at end of file
diff --git a/func-app/graphrag/common/progress/__init__.py b/func-app/graphrag/common/progress/__init__.py
new file mode 100644
index 0000000000..df6a21523d
--- /dev/null
+++ b/func-app/graphrag/common/progress/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Progress-reporting components."""
+
+from .types import NullProgressReporter, PrintProgressReporter, ProgressReporter
+
+__all__ = ["NullProgressReporter", "PrintProgressReporter", "ProgressReporter"]
diff --git a/func-app/graphrag/common/progress/rich.py b/func-app/graphrag/common/progress/rich.py
new file mode 100644
index 0000000000..362b64f0c8
--- /dev/null
+++ b/func-app/graphrag/common/progress/rich.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Rich-based progress reporter for CLI use."""
+
+# Print iterations progress
+import asyncio
+
+from datashaper import Progress as DSProgress
+from rich.console import Console, Group
+from rich.live import Live
+from rich.progress import Progress, TaskID, TimeElapsedColumn
+from rich.spinner import Spinner
+from rich.tree import Tree
+
+from .types import ProgressReporter
+
+
+# https://stackoverflow.com/a/34325723
+class RichProgressReporter(ProgressReporter):
+    """A rich-based progress reporter for CLI use."""
+
+    _console: Console
+    _group: Group
+    _tree: Tree
+    _live: Live
+    _task: TaskID | None = None
+    _prefix: str
+    _transient: bool
+    _disposing: bool = False
+    _progressbar: Progress
+    _last_refresh: float = 0
+
+    def dispose(self) -> None:
+        """Dispose of the progress reporter."""
+        self._disposing = True
+        self._live.stop()
+
+    @property
+    def console(self) -> Console:
+        """Get the console."""
+        return self._console
+
+    @property
+    def group(self) -> Group:
+        """Get the group."""
+        return self._group
+
+    @property
+    def tree(self) -> Tree:
+        """Get the tree."""
+        return self._tree
+
+    @property
+    def live(self) -> Live:
+        """Get the live."""
+        return self._live
+
+    def __init__(
+        self,
+        prefix: str,
+        parent: "RichProgressReporter | None" = None,
+        transient: bool = True,
+    ) -> None:
+        """Create a new rich-based progress reporter."""
+        self._prefix = prefix
+
+        if parent is None:
+            console = Console()
+            group = Group(Spinner("dots", prefix), fit=True)
+            tree = Tree(group)
+            live = Live(
+                tree, console=console, refresh_per_second=1, vertical_overflow="crop"
+            )
+            live.start()
+
+            self._console = console
+            self._group = group
+            self._tree = tree
+            self._live = live
+            self._transient = False
+        else:
+            self._console = parent.console
+            self._group = parent.group
+            progress_columns = [*Progress.get_default_columns(), TimeElapsedColumn()]
+            self._progressbar = Progress(
+                *progress_columns, console=self._console, transient=transient
+            )
+
+            tree = Tree(prefix)
+            tree.add(self._progressbar)
+            tree.hide_root = True
+
+            if parent is not None:
+                parent_tree = parent.tree
+                parent_tree.hide_root = False
+                parent_tree.add(tree)
+
+            self._tree = tree
+            self._live = parent.live
+            self._transient = transient
+
+        self.refresh()
+
+    def refresh(self) -> None:
+        """Perform a debounced refresh."""
+        now = asyncio.get_event_loop().time()
+        duration = now - self._last_refresh
+        if duration > 0.1:
+            self._last_refresh = now
+            self.force_refresh()
+
+    def force_refresh(self) -> None:
+        """Force a refresh."""
+        self.live.refresh()
+
+    def stop(self) -> None:
+        """Stop the progress reporter."""
+        self._live.stop()
+
+    def child(self, prefix: str, transient: bool = True) -> ProgressReporter:
+        """Create a child progress bar."""
+        return RichProgressReporter(parent=self, prefix=prefix, transient=transient)
+
+    def error(self, message: str) -> None:
+        """Report an error."""
+        self._console.print(f"❌ [red]{message}[/red]")
+
+    def warning(self, message: str) -> None:
+        """Report a warning."""
+        self._console.print(f"⚠️ [yellow]{message}[/yellow]")
+
+    def success(self, message: str) -> None:
+        """Report success."""
+        self._console.print(f"🚀 [green]{message}[/green]")
+
+    def info(self, message: str) -> None:
+        """Report information."""
+        self._console.print(message)
+
+    def __call__(self, progress_update: DSProgress) -> None:
+        """Update progress."""
+        if self._disposing:
+            return
+        progressbar = self._progressbar
+
+        if self._task is None:
+            self._task = progressbar.add_task(self._prefix)
+
+        progress_description = ""
+        if progress_update.description is not None:
+            progress_description = f" - {progress_update.description}"
+
+        completed = progress_update.completed_items or progress_update.percent
+        total = progress_update.total_items or 1
+        progressbar.update(
+            self._task,
+            completed=completed,
+            total=total,
+            description=f"{self._prefix}{progress_description}",
+        )
+        if completed == total and self._transient:
+            progressbar.update(self._task, visible=False)
+
+        self.refresh()
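A usage sketch for the reporter above, assuming `datashaper.Progress` accepts the same keyword arguments used by the helpers in this patch (`total_items`, `completed_items`, `description`):

```python
from datashaper import Progress

from graphrag.common.progress.rich import RichProgressReporter

# Root reporter owns the Live display; children render nested bars in its tree.
reporter = RichProgressReporter("pipeline")
step = reporter.child("load documents", transient=False)
for i in range(3):
    step(Progress(total_items=3, completed_items=i + 1, description="chunking"))
reporter.success("all documents loaded")
reporter.dispose()
```

The debounced `refresh` keeps redraw cost bounded: updates arriving faster than every 0.1 s simply wait for the next refresh window.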
+ """ + + @abstractmethod + def __call__(self, update: Progress): + """Update progress.""" + + @abstractmethod + def dispose(self): + """Dispose of the progress reporter.""" + + @abstractmethod + def child(self, prefix: str, transient=True) -> "ProgressReporter": + """Create a child progress bar.""" + + @abstractmethod + def force_refresh(self) -> None: + """Force a refresh.""" + + @abstractmethod + def stop(self) -> None: + """Stop the progress reporter.""" + + @abstractmethod + def error(self, message: str) -> None: + """Report an error.""" + + @abstractmethod + def warning(self, message: str) -> None: + """Report a warning.""" + + @abstractmethod + def info(self, message: str) -> None: + """Report information.""" + + @abstractmethod + def success(self, message: str) -> None: + """Report success.""" + + +class NullProgressReporter(ProgressReporter): + """A progress reporter that does nothing.""" + + def __call__(self, update: Progress) -> None: + """Update progress.""" + + def dispose(self) -> None: + """Dispose of the progress reporter.""" + + def child(self, prefix: str, transient: bool = True) -> ProgressReporter: + """Create a child progress bar.""" + return self + + def force_refresh(self) -> None: + """Force a refresh.""" + + def stop(self) -> None: + """Stop the progress reporter.""" + + def error(self, message: str) -> None: + """Report an error.""" + + def warning(self, message: str) -> None: + """Report a warning.""" + + def info(self, message: str) -> None: + """Report information.""" + + def success(self, message: str) -> None: + """Report success.""" + + +class PrintProgressReporter(ProgressReporter): + """A progress reporter that does nothing.""" + + prefix: str + + def __init__(self, prefix: str): + """Create a new progress reporter.""" + self.prefix = prefix + print(f"\n{self.prefix}", end="") # noqa T201 + + def __call__(self, update: Progress) -> None: + """Update progress.""" + print(".", end="") # noqa T201 + + def dispose(self) -> None: + """Dispose of the progress reporter.""" + + def child(self, prefix: str, transient: bool = True) -> "ProgressReporter": + """Create a child progress bar.""" + return PrintProgressReporter(prefix) + + def stop(self) -> None: + """Stop the progress reporter.""" + + def force_refresh(self) -> None: + """Force a refresh.""" + + def error(self, message: str) -> None: + """Report an error.""" + print(f"\n{self.prefix}ERROR: {message}") # noqa T201 + + def warning(self, message: str) -> None: + """Report a warning.""" + print(f"\n{self.prefix}WARNING: {message}") # noqa T201 + + def info(self, message: str) -> None: + """Report information.""" + print(f"\n{self.prefix}INFO: {message}") # noqa T201 + + def success(self, message: str) -> None: + """Report success.""" + print(f"\n{self.prefix}SUCCESS: {message}") # noqa T201 diff --git a/func-app/graphrag/common/storage/__init__.py b/func-app/graphrag/common/storage/__init__.py new file mode 100644 index 0000000000..7ca943db52 --- /dev/null +++ b/func-app/graphrag/common/storage/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. 
diff --git a/func-app/graphrag/common/storage/__init__.py b/func-app/graphrag/common/storage/__init__.py
new file mode 100644
index 0000000000..7ca943db52
--- /dev/null
+++ b/func-app/graphrag/common/storage/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine storage package root."""
+
+from .blob_pipeline_storage import BlobPipelineStorage, create_blob_storage
+from .file_pipeline_storage import FilePipelineStorage
+from .load_storage import load_storage
+from .memory_pipeline_storage import MemoryPipelineStorage
+from .typing import PipelineStorage
+
+__all__ = [
+    "BlobPipelineStorage",
+    "FilePipelineStorage",
+    "MemoryPipelineStorage",
+    "PipelineStorage",
+    "create_blob_storage",
+    "load_storage",
+]
diff --git a/func-app/graphrag/common/storage/blob_pipeline_storage.py b/func-app/graphrag/common/storage/blob_pipeline_storage.py
new file mode 100644
index 0000000000..568ec89bcc
--- /dev/null
+++ b/func-app/graphrag/common/storage/blob_pipeline_storage.py
@@ -0,0 +1,375 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Azure Blob Storage implementation of PipelineStorage."""
+
+import logging
+import re
+from collections.abc import Iterator
+from pathlib import Path
+from typing import Any
+
+from azure.identity import DefaultAzureCredential
+from azure.storage.blob import BlobServiceClient
+from datashaper import Progress
+
+from graphrag.common.progress import ProgressReporter
+
+from .typing import PipelineStorage
+
+log = logging.getLogger(__name__)
+
+
+class BlobPipelineStorage(PipelineStorage):
+    """The Blob-Storage implementation."""
+
+    _connection_string: str | None
+    _container_name: str
+    _path_prefix: str
+    _encoding: str
+    _storage_account_blob_url: str | None
+
+    def __init__(
+        self,
+        connection_string: str | None,
+        container_name: str,
+        encoding: str | None = None,
+        path_prefix: str | None = None,
+        storage_account_blob_url: str | None = None,
+        overwrite: bool = False
+    ):
+        """Create a new BlobStorage instance."""
+        if connection_string:
+            self._blob_service_client = BlobServiceClient.from_connection_string(
+                connection_string
+            )
+        else:
+            if storage_account_blob_url is None:
+                msg = "Either connection_string or storage_account_blob_url must be provided."
+                raise ValueError(msg)
+
+            self._blob_service_client = BlobServiceClient(
+                account_url=storage_account_blob_url,
+                credential=DefaultAzureCredential(managed_identity_client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"),
+            )
+        self._encoding = encoding or "utf-8"
+        self._container_name = container_name
+        self._connection_string = connection_string
+        self._overwrite = overwrite
+        self._path_prefix = path_prefix or ""
+        self._storage_account_blob_url = storage_account_blob_url
+        self._storage_account_name = (
+            storage_account_blob_url.split("//")[1].split(".")[0]
+            if storage_account_blob_url
+            else None
+        )
+        log.info(
+            "creating blob storage at container=%s, path=%s",
+            self._container_name,
+            self._path_prefix,
+        )
+        self.create_container()
+
+    def create_container(self) -> None:
+        """Create the container if it does not exist."""
+        if not self.container_exists():
+            container_name = self._container_name
+            container_names = [
+                container.name
+                for container in self._blob_service_client.list_containers()
+            ]
+            if container_name not in container_names:
+                self._blob_service_client.create_container(container_name)
+
+    def delete_container(self) -> None:
+        """Delete the container."""
+        if self.container_exists():
+            self._blob_service_client.delete_container(self._container_name)
+
+    def container_exists(self) -> bool:
+        """Check if the container exists."""
+        container_name = self._container_name
+        container_names = [
+            container.name for container in self._blob_service_client.list_containers()
+        ]
+        return container_name in container_names
+
+    def find(
+        self,
+        file_pattern: re.Pattern[str],
+        base_dir: str | None = None,
+        progress: ProgressReporter | None = None,
+        file_filter: dict[str, Any] | None = None,
+        max_count=-1,
+    ) -> Iterator[tuple[str, dict[str, Any]]]:
+        """Find blobs in a container using a file pattern, as well as a custom filter function.
+
+        Params:
+            base_dir: The name of the base container.
+            file_pattern: The file pattern to use.
+            file_filter: A dictionary of key-value pairs to filter the blobs.
+            max_count: The maximum number of blobs to return. If -1, all blobs are returned.
+
+        Returns
+        -------
+            An iterator of blob names and their corresponding regex matches.
+ """ + base_dir = base_dir or "" + + log.info( + "search container %s for files matching %s", + self._container_name, + file_pattern.pattern, + ) + + def blobname(blob_name: str) -> str: + if blob_name.startswith(self._path_prefix): + blob_name = blob_name.replace(self._path_prefix, "", 1) + if blob_name.startswith("/"): + blob_name = blob_name[1:] + return blob_name + + def item_filter(item: dict[str, Any]) -> bool: + if file_filter is None: + return True + + return all(re.match(value, item[key]) for key, value in file_filter.items()) + + try: + container_client = self._blob_service_client.get_container_client( + self._container_name + ) + all_blobs = list(container_client.list_blobs()) + + num_loaded = 0 + num_total = len(list(all_blobs)) + num_filtered = 0 + for blob in all_blobs: + match = file_pattern.match(blob.name) + if match and blob.name.startswith(base_dir): + group = match.groupdict() + if item_filter(group): + yield (blobname(blob.name), group) + num_loaded += 1 + if max_count > 0 and num_loaded >= max_count: + break + else: + num_filtered += 1 + else: + num_filtered += 1 + if progress is not None: + progress( + _create_progress_status(num_loaded, num_filtered, num_total) + ) + except Exception: + log.exception( + "Error finding blobs: base_dir=%s, file_pattern=%s, file_filter=%s", + base_dir, + file_pattern, + file_filter, + ) + raise + + async def get( + self, key: str, as_bytes: bool | None = False, encoding: str | None = None + ) -> Any: + """Get a value from the cache.""" + try: + key = self._keyname(key) + container_client = self._blob_service_client.get_container_client( + self._container_name + ) + blob_client = container_client.get_blob_client(key) + blob_data = blob_client.download_blob().readall() + if not as_bytes: + coding = encoding or "utf-8" + blob_data = blob_data.decode(coding) + except Exception: + log.exception("Error getting key %s", key) + return None + else: + return blob_data + + async def set(self, key: str, value: Any, encoding: str | None = None) -> None: + """Set a value in the cache.""" + try: + key = self._keyname(key) + container_client = self._blob_service_client.get_container_client( + self._container_name + ) + blob_client = container_client.get_blob_client(key) + if blob_client.exists() and not self._overwrite: + ValueError("Artifacts already exists, make sure output folder is empty.") + if isinstance(value, bytes): + blob_client.upload_blob(value, overwrite=True) + else: + coding = encoding or "utf-8" + blob_client.upload_blob(value.encode(coding), overwrite=True) + except Exception: + log.exception("Error setting key %s: %s", key) + + def set_df_json(self, key: str, dataframe: Any) -> None: + """Set a json dataframe.""" + if self._connection_string is None and self._storage_account_name: + dataframe.to_json( + self._abfs_url(key), + storage_options={ + "account_name": self._storage_account_name, + "credential": DefaultAzureCredential(), + }, + orient="records", + lines=True, + force_ascii=False, + ) + else: + dataframe.to_json( + self._abfs_url(key), + storage_options={"connection_string": self._connection_string}, + orient="records", + lines=True, + force_ascii=False, + ) + + def set_df_parquet(self, key: str, dataframe: Any) -> None: + """Set a parquet dataframe.""" + if self._connection_string is None and self._storage_account_name: + dataframe.to_parquet( + self._abfs_url(key), + storage_options={ + "account_name": self._storage_account_name, + "credential": DefaultAzureCredential(), + }, + ) + else: + dataframe.to_parquet( + 
+
+    async def get(
+        self, key: str, as_bytes: bool | None = False, encoding: str | None = None
+    ) -> Any:
+        """Get a value from the cache."""
+        try:
+            key = self._keyname(key)
+            container_client = self._blob_service_client.get_container_client(
+                self._container_name
+            )
+            blob_client = container_client.get_blob_client(key)
+            blob_data = blob_client.download_blob().readall()
+            if not as_bytes:
+                coding = encoding or "utf-8"
+                blob_data = blob_data.decode(coding)
+        except Exception:
+            log.exception("Error getting key %s", key)
+            return None
+        else:
+            return blob_data
+
+    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
+        """Set a value in the cache."""
+        try:
+            key = self._keyname(key)
+            container_client = self._blob_service_client.get_container_client(
+                self._container_name
+            )
+            blob_client = container_client.get_blob_client(key)
+            if blob_client.exists() and not self._overwrite:
+                raise ValueError("Artifacts already exist; make sure the output folder is empty.")
+            if isinstance(value, bytes):
+                blob_client.upload_blob(value, overwrite=True)
+            else:
+                coding = encoding or "utf-8"
+                blob_client.upload_blob(value.encode(coding), overwrite=True)
+        except Exception:
+            log.exception("Error setting key %s", key)
+
+    def set_df_json(self, key: str, dataframe: Any) -> None:
+        """Set a json dataframe."""
+        if self._connection_string is None and self._storage_account_name:
+            dataframe.to_json(
+                self._abfs_url(key),
+                storage_options={
+                    "account_name": self._storage_account_name,
+                    "credential": DefaultAzureCredential(),
+                },
+                orient="records",
+                lines=True,
+                force_ascii=False,
+            )
+        else:
+            dataframe.to_json(
+                self._abfs_url(key),
+                storage_options={"connection_string": self._connection_string},
+                orient="records",
+                lines=True,
+                force_ascii=False,
+            )
+
+    def set_df_parquet(self, key: str, dataframe: Any) -> None:
+        """Set a parquet dataframe."""
+        if self._connection_string is None and self._storage_account_name:
+            dataframe.to_parquet(
+                self._abfs_url(key),
+                storage_options={
+                    "account_name": self._storage_account_name,
+                    "credential": DefaultAzureCredential(),
+                },
+            )
+        else:
+            dataframe.to_parquet(
+                self._abfs_url(key),
+                storage_options={"connection_string": self._connection_string},
+            )
+
+    async def has(self, key: str) -> bool:
+        """Check if a key exists in the cache."""
+        key = self._keyname(key)
+        container_client = self._blob_service_client.get_container_client(
+            self._container_name
+        )
+        blob_client = container_client.get_blob_client(key)
+        return blob_client.exists()
+
+    async def delete(self, key: str) -> None:
+        """Delete a key from the cache."""
+        key = self._keyname(key)
+        container_client = self._blob_service_client.get_container_client(
+            self._container_name
+        )
+        blob_client = container_client.get_blob_client(key)
+        blob_client.delete_blob()
+
+    async def clear(self) -> None:
+        """Clear the cache."""
+
+    def child(self, name: str | None) -> "PipelineStorage":
+        """Create a child storage instance."""
+        if name is None:
+            return self
+        path = str(Path(self._path_prefix) / name)
+        return BlobPipelineStorage(
+            self._connection_string,
+            self._container_name,
+            self._encoding,
+            path,
+            self._storage_account_blob_url,
+        )
+
+    def _keyname(self, key: str) -> str:
+        """Get the key name."""
+        return str(Path(self._path_prefix) / key)
+
+    def _abfs_url(self, key: str) -> str:
+        """Get the ABFS URL."""
+        path = str(Path(self._container_name) / self._path_prefix / key)
+        return f"abfs://{path}"
+
+
+def create_blob_storage(
+    connection_string: str | None,
+    storage_account_blob_url: str | None,
+    container_name: str,
+    base_dir: str | None,
+) -> PipelineStorage:
+    """Create a blob based storage."""
+    log.info("Creating blob storage at %s", container_name)
+    if container_name is None:
+        msg = "No container name provided for blob storage."
+        raise ValueError(msg)
+    if connection_string is None and storage_account_blob_url is None:
+        msg = "No connection string or storage account blob url provided for blob storage."
+        raise ValueError(msg)
+    return BlobPipelineStorage(
+        connection_string,
+        container_name,
+        path_prefix=base_dir,
+        storage_account_blob_url=storage_account_blob_url,
+    )
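A roundtrip sketch using the factory above; the account URL and container are placeholders, and `DefaultAzureCredential` must be able to resolve an identity for the calls to succeed:

```python
import asyncio

from graphrag.common.storage.blob_pipeline_storage import create_blob_storage

storage = create_blob_storage(
    connection_string=None,
    storage_account_blob_url="https://<account>.blob.core.windows.net",
    container_name="pipeline-artifacts",
    base_dir="output/run-1",
)


async def main() -> None:
    # Keys are prefixed with base_dir by _keyname before hitting the container.
    await storage.set("stats.json", '{"documents": 42}')
    print(await storage.get("stats.json"))


asyncio.run(main())
```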
+ ) + + # Check for hyphens at the end of the name + if container_name[-1] == "-": + return ValueError( + f"Container name cannot end with a hyphen. Name provided was {container_name}." + ) + + return True + + +def _create_progress_status( + num_loaded: int, num_filtered: int, num_total: int +) -> Progress: + return Progress( + total_items=num_total, + completed_items=num_loaded + num_filtered, + description=f"{num_loaded} files loaded ({num_filtered} filtered)", + ) diff --git a/func-app/graphrag/common/storage/file_pipeline_storage.py b/func-app/graphrag/common/storage/file_pipeline_storage.py new file mode 100644 index 0000000000..212783e41f --- /dev/null +++ b/func-app/graphrag/common/storage/file_pipeline_storage.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'FileStorage' and 'FilePipelineStorage' models.""" + +import logging +import os +import re +import shutil +from collections.abc import Iterator +from pathlib import Path +from typing import Any, cast + +import aiofiles +from aiofiles.os import remove +from aiofiles.ospath import exists +from datashaper import Progress + +from graphrag.common.progress import ProgressReporter + +from .typing import PipelineStorage + +log = logging.getLogger(__name__) + + +class FilePipelineStorage(PipelineStorage): + """File storage class definition.""" + + _root_dir: str + _encoding: str + + def __init__(self, root_dir: str | None = None, encoding: str | None = None): + """Init method definition.""" + self._root_dir = root_dir or "" + self._encoding = encoding or "utf-8" + Path(self._root_dir).mkdir(parents=True, exist_ok=True) + + def find( + self, + file_pattern: re.Pattern[str], + base_dir: str | None = None, + progress: ProgressReporter | None = None, + file_filter: dict[str, Any] | None = None, + max_count=-1, + ) -> Iterator[tuple[str, dict[str, Any]]]: + """Find files in the storage using a file pattern, as well as a custom filter function.""" + + def item_filter(item: dict[str, Any]) -> bool: + if file_filter is None: + return True + + return all(re.match(value, item[key]) for key, value in file_filter.items()) + + search_path = Path(self._root_dir) / (base_dir or "") + log.info("search %s for files matching %s", search_path, file_pattern.pattern) + all_files = list(search_path.rglob("**/*")) + num_loaded = 0 + num_total = len(all_files) + num_filtered = 0 + for file in all_files: + match = file_pattern.match(f"{file}") + if match: + group = match.groupdict() + if item_filter(group): + filename = f"{file}".replace(self._root_dir, "") + if filename.startswith(os.sep): + filename = filename[1:] + yield (filename, group) + num_loaded += 1 + if max_count > 0 and num_loaded >= max_count: + break + else: + num_filtered += 1 + else: + num_filtered += 1 + if progress is not None: + progress(_create_progress_status(num_loaded, num_filtered, num_total)) + + async def get( + self, key: str, as_bytes: bool | None = False, encoding: str | None = None + ) -> Any: + """Get method definition.""" + file_path = join_path(self._root_dir, key) + + if await self.has(key): + return await self._read_file(file_path, as_bytes, encoding) + if await exists(key): + # Lookup for key, as it is pressumably a new file loaded from inputs + # and not yet written to storage + return await self._read_file(key, as_bytes, encoding) + + return None + + async def _read_file( + self, + path: str | Path, + as_bytes: bool | None = False, + encoding: str | None = None, + ) -> Any: + """Read the contents 
+        read_type = "rb" if as_bytes else "r"
+        encoding = None if as_bytes else (encoding or self._encoding)
+
+        async with aiofiles.open(
+            path,
+            cast(Any, read_type),
+            encoding=encoding,
+        ) as f:
+            return await f.read()
+
+    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
+        """Set method definition."""
+        is_bytes = isinstance(value, bytes)
+        write_type = "wb" if is_bytes else "w"
+        encoding = None if is_bytes else encoding or self._encoding
+        os.makedirs(os.path.dirname(join_path(self._root_dir, key)), mode=0o777, exist_ok=True)
+        async with aiofiles.open(
+            join_path(self._root_dir, key),
+            cast(Any, write_type),
+            encoding=encoding,
+        ) as f:
+            await f.write(value)
+
+    async def has(self, key: str) -> bool:
+        """Has method definition."""
+        return await exists(join_path(self._root_dir, key))
+
+    async def delete(self, key: str) -> None:
+        """Delete method definition."""
+        if await self.has(key):
+            await remove(join_path(self._root_dir, key))
+
+    async def clear(self) -> None:
+        """Clear method definition."""
+        for file in Path(self._root_dir).glob("*"):
+            if file.is_dir():
+                shutil.rmtree(file)
+            else:
+                file.unlink()
+
+    def child(self, name: str | None) -> "PipelineStorage":
+        """Create a child storage instance."""
+        if name is None:
+            return self
+        return FilePipelineStorage(str(Path(self._root_dir) / Path(name)))
+
+
+def join_path(file_path: str, file_name: str) -> Path:
+    """Join a path and a file. Independent of the OS."""
+    return Path(file_path) / Path(file_name).parent / Path(file_name).name
+
+
+def create_file_storage(out_dir: str | None) -> PipelineStorage:
+    """Create a file based storage."""
+    log.info("Creating file storage at %s", out_dir)
+    return FilePipelineStorage(out_dir)
+
+
+def _create_progress_status(
+    num_loaded: int, num_filtered: int, num_total: int
+) -> Progress:
+    return Progress(
+        total_items=num_total,
+        completed_items=num_loaded + num_filtered,
+        description=f"{num_loaded} files loaded ({num_filtered} filtered)",
    )
diff --git a/func-app/graphrag/common/storage/load_storage.py b/func-app/graphrag/common/storage/load_storage.py
new file mode 100644
index 0000000000..24a6675a04
--- /dev/null
+++ b/func-app/graphrag/common/storage/load_storage.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2024 Microsoft Corporation.
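A minimal usage sketch of the FilePipelineStorage defined above, assuming the func-app import path, a throwaway temp directory, and an asyncio driver (none of which are part of this patch):

import asyncio
import re
import tempfile

from graphrag.common.storage.file_pipeline_storage import FilePipelineStorage


async def main() -> None:
    # Write a key, read it back, then enumerate it with a pattern.
    storage = FilePipelineStorage(root_dir=tempfile.mkdtemp())
    await storage.set("docs/sample.txt", "hello world")  # text mode ("w"), utf-8 by default
    assert await storage.get("docs/sample.txt") == "hello world"
    # find() is synchronous and yields (relative_name, regex_groups) pairs
    for name, groups in storage.find(re.compile(r".*\.txt$")):
        print(name, groups)


asyncio.run(main())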
+# Licensed under the MIT License + +"""A module containing load_storage method definition.""" + +from __future__ import annotations + +from typing import cast + +from graphrag.config import StorageType +from graphrag.common.config.storage import ( + PipelineBlobStorageConfig, + PipelineFileStorageConfig, + PipelineStorageConfig, +) + +from .blob_pipeline_storage import create_blob_storage +from .file_pipeline_storage import create_file_storage +from .memory_pipeline_storage import create_memory_storage + + +def load_storage(config: PipelineStorageConfig): + """Load the storage for a pipeline.""" + match config.type: + case StorageType.memory: + return create_memory_storage() + case StorageType.blob: + config = cast(PipelineBlobStorageConfig, config) + return create_blob_storage( + config.connection_string, + config.storage_account_blob_url, + config.container_name, + config.base_dir, + ) + case StorageType.file: + config = cast(PipelineFileStorageConfig, config) + return create_file_storage(config.base_dir) + case _: + msg = f"Unknown storage type: {config.type}" + raise ValueError(msg) diff --git a/func-app/graphrag/common/storage/memory_pipeline_storage.py b/func-app/graphrag/common/storage/memory_pipeline_storage.py new file mode 100644 index 0000000000..2d1382e0af --- /dev/null +++ b/func-app/graphrag/common/storage/memory_pipeline_storage.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'InMemoryStorage' model.""" + +from typing import Any + +from .file_pipeline_storage import FilePipelineStorage +from .typing import PipelineStorage + + +class MemoryPipelineStorage(FilePipelineStorage): + """In memory storage class definition.""" + + _storage: dict[str, Any] + + def __init__(self): + """Init method definition.""" + super().__init__(root_dir=".output") + self._storage = {} + + async def get( + self, key: str, as_bytes: bool | None = None, encoding: str | None = None + ) -> Any: + """Get the value for the given key. + + Args: + - key - The key to get the value for. + - as_bytes - Whether or not to return the value as bytes. + + Returns + ------- + - output - The value for the given key. + """ + return self._storage.get(key) or await super().get(key, as_bytes, encoding) + + async def set( + self, key: str, value: str | bytes | None, encoding: str | None = None + ) -> None: + """Set the value for the given key. + + Args: + - key - The key to set the value for. + - value - The value to set. + """ + self._storage[key] = value + + async def has(self, key: str) -> bool: + """Return True if the given key exists in the storage. + + Args: + - key - The key to check for. + + Returns + ------- + - output - True if the key exists in the storage, False otherwise. + """ + return key in self._storage or await super().has(key) + + async def delete(self, key: str) -> None: + """Delete the given key from the storage. + + Args: + - key - The key to delete. 
+ """ + del self._storage[key] + + async def clear(self) -> None: + """Clear the storage.""" + self._storage.clear() + + def child(self, name: str | None) -> "PipelineStorage": + """Create a child storage instance.""" + return self + + +def create_memory_storage() -> PipelineStorage: + """Create memory storage.""" + return MemoryPipelineStorage() diff --git a/func-app/graphrag/common/storage/typing.py b/func-app/graphrag/common/storage/typing.py new file mode 100644 index 0000000000..c5f0de3265 --- /dev/null +++ b/func-app/graphrag/common/storage/typing.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'PipelineStorage' model.""" + +import re +from abc import ABCMeta, abstractmethod +from collections.abc import Iterator +from typing import Any + +from graphrag.common.progress import ProgressReporter + + +class PipelineStorage(metaclass=ABCMeta): + """Provide a storage interface for the pipeline. This is where the pipeline will store its output data.""" + + @abstractmethod + def find( + self, + file_pattern: re.Pattern[str], + base_dir: str | None = None, + progress: ProgressReporter | None = None, + file_filter: dict[str, Any] | None = None, + max_count=-1, + ) -> Iterator[tuple[str, dict[str, Any]]]: + """Find files in the storage using a file pattern, as well as a custom filter function.""" + + @abstractmethod + async def get( + self, key: str, as_bytes: bool | None = None, encoding: str | None = None + ) -> Any: + """Get the value for the given key. + + Args: + - key - The key to get the value for. + - as_bytes - Whether or not to return the value as bytes. + + Returns + ------- + - output - The value for the given key. + """ + + @abstractmethod + async def set( + self, key: str, value: str | bytes | None, encoding: str | None = None + ) -> None: + """Set the value for the given key. + + Args: + - key - The key to set the value for. + - value - The value to set. + """ + + @abstractmethod + async def has(self, key: str) -> bool: + """Return True if the given key exists in the storage. + + Args: + - key - The key to check for. + + Returns + ------- + - output - True if the key exists in the storage, False otherwise. + """ + + @abstractmethod + async def delete(self, key: str) -> None: + """Delete the given key from the storage. + + Args: + - key - The key to delete. 
+ """ + + @abstractmethod + async def clear(self) -> None: + """Clear the storage.""" + + @abstractmethod + def child(self, name: str | None) -> "PipelineStorage": + """Create a child storage instance.""" diff --git a/func-app/graphrag/common/utils/common_utils.py b/func-app/graphrag/common/utils/common_utils.py new file mode 100644 index 0000000000..d53345d245 --- /dev/null +++ b/func-app/graphrag/common/utils/common_utils.py @@ -0,0 +1,11 @@ +import uuid + +def is_valid_guid(guid_str): + """Utility to check valid Guid.""" + try: + # Attempt to create a UUID object + uuid_obj = uuid.UUID(guid_str, version=4) + # Check if the string representation matches the UUID object + return str(uuid_obj) == guid_str + except ValueError: + return False \ No newline at end of file diff --git a/func-app/graphrag/common/utils/context_utils.py b/func-app/graphrag/common/utils/context_utils.py new file mode 100644 index 0000000000..687779e9c1 --- /dev/null +++ b/func-app/graphrag/common/utils/context_utils.py @@ -0,0 +1,9 @@ +from graphrag.config import ( + GraphRagConfig, +) + +def get_files_by_contextid(config: GraphRagConfig, context_id: str): + """Utility function to get files by context id""" + # General: eventually this will be comming from cosmos db or any other storage + filesInContext = config.query_context.files + return filesInContext \ No newline at end of file diff --git a/func-app/graphrag/config/__init__.py b/func-app/graphrag/config/__init__.py new file mode 100644 index 0000000000..5870c4ae71 --- /dev/null +++ b/func-app/graphrag/config/__init__.py @@ -0,0 +1,128 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine default config package root.""" + +from .create_graphrag_config import ( + create_graphrag_config, +) +from .enums import ( + CacheType, + ContextSwitchType, + InputFileType, + InputType, + LLMType, + ReportingType, + StorageType, + TextEmbeddingTarget, +) +from .errors import ( + ApiKeyMissingError, + AzureApiBaseMissingError, + AzureDeploymentNameMissingError, +) +from .input_models import ( + CacheConfigInput, + ChunkingConfigInput, + ClaimExtractionConfigInput, + ClusterGraphConfigInput, + CommunityReportsConfigInput, + EmbedGraphConfigInput, + EntityExtractionConfigInput, + GlobalSearchConfigInput, + GraphRagConfigInput, + InputConfigInput, + LLMConfigInput, + LLMParametersInput, + LocalSearchConfigInput, + ParallelizationParametersInput, + ReportingConfigInput, + SnapshotsConfigInput, + StorageConfigInput, + SummarizeDescriptionsConfigInput, + TextEmbeddingConfigInput, + UmapConfigInput, +) +from .models import ( + CacheConfig, + ChunkingConfig, + ClaimExtractionConfig, + ClusterGraphConfig, + CommunityReportsConfig, + EmbedGraphConfig, + EntityExtractionConfig, + GlobalSearchConfig, + GraphRagConfig, + InputConfig, + LLMConfig, + LLMParameters, + LocalSearchConfig, + ParallelizationParameters, + QueryContextConfig, + ReportingConfig, + SnapshotsConfig, + StorageConfig, + SummarizeDescriptionsConfig, + TextEmbeddingConfig, + UmapConfig, +) +from .read_dotenv import read_dotenv + +__all__ = [ + "ApiKeyMissingError", + "AzureApiBaseMissingError", + "AzureDeploymentNameMissingError", + "CacheConfig", + "ContextSwitchType", + "CacheConfigInput", + "CacheType", + "ChunkingConfig", + "ChunkingConfigInput", + "ClaimExtractionConfig", + "ClaimExtractionConfigInput", + "ClusterGraphConfig", + "ClusterGraphConfigInput", + "CommunityReportsConfig", + "CommunityReportsConfigInput", + "EmbedGraphConfig", + "EmbedGraphConfigInput", 
+ "EntityExtractionConfig", + "EntityExtractionConfigInput", + "GlobalSearchConfig", + "GlobalSearchConfigInput", + "GraphRagConfig", + "GraphRagConfigInput", + "InputConfig", + "InputConfigInput", + "InputFileType", + "InputType", + "LLMConfig", + "LLMConfigInput", + "LLMParameters", + "LLMParametersInput", + "LLMType", + "LocalSearchConfig", + "LocalSearchConfigInput", + "ParallelizationParameters", + "ParallelizationParametersInput", + "QueryContextConfig", + "QueryContextConfigInput", + "ReportingConfig", + "ReportingConfigInput", + "ReportingType", + "SnapshotsConfig", + "SnapshotsConfigInput", + "StorageConfig", + "StorageConfigInput", + "StorageType", + "StorageType", + "SummarizeDescriptionsConfig", + "SummarizeDescriptionsConfigInput", + "TextEmbeddingConfig", + "TextEmbeddingConfigInput", + "TextEmbeddingTarget", + "UmapConfig", + "UmapConfigInput", + "create_graphrag_config", + "read_dotenv", +] diff --git a/func-app/graphrag/config/create_graphrag_config.py b/func-app/graphrag/config/create_graphrag_config.py new file mode 100644 index 0000000000..34b953345e --- /dev/null +++ b/func-app/graphrag/config/create_graphrag_config.py @@ -0,0 +1,687 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration, loaded from environment variables.""" + +import os +from enum import Enum +from pathlib import Path +from typing import cast + +from datashaper import AsyncType +from environs import Env +from pydantic import TypeAdapter + +import graphrag.config.defaults as defs + +from .enums import ( + CacheType, + InputFileType, + InputType, + LLMType, + ReportingType, + StorageType, + TextEmbeddingTarget, +) +from .environment_reader import EnvironmentReader +from .errors import ( + ApiKeyMissingError, + AzureApiBaseMissingError, + AzureDeploymentNameMissingError, +) +from .input_models import ( + GraphRagConfigInput, + LLMConfigInput, +) +from .models import ( + CacheConfig, + ChunkingConfig, + ClaimExtractionConfig, + ClusterGraphConfig, + CommunityReportsConfig, + EmbedGraphConfig, + EntityExtractionConfig, + GlobalSearchConfig, + GraphRagConfig, + InputConfig, + LLMParameters, + LocalSearchConfig, + ParallelizationParameters, + QueryContextConfig, + ReportingConfig, + SnapshotsConfig, + StorageConfig, + SummarizeDescriptionsConfig, + TextEmbeddingConfig, + UmapConfig, + GraphDBConfig, +) +from .read_dotenv import read_dotenv + +InputModelValidator = TypeAdapter(GraphRagConfigInput) + + +def create_graphrag_config( + values: GraphRagConfigInput | None = None, root_dir: str | None = None +) -> GraphRagConfig: + """Load Configuration Parameters from a dictionary.""" + values = values or {} + root_dir = root_dir or str(Path.cwd()) + env = _make_env(root_dir) + _token_replace(cast(dict, values)) + InputModelValidator.validate_python(values, strict=True) + + reader = EnvironmentReader(env) + + def hydrate_async_type(input: LLMConfigInput, base: AsyncType) -> AsyncType: + value = input.get(Fragment.async_mode) + return AsyncType(value) if value else base + + def hydrate_llm_params( + config: LLMConfigInput, base: LLMParameters + ) -> LLMParameters: + with reader.use(config.get("llm")): + llm_type = reader.str(Fragment.type) + llm_type = LLMType(llm_type) if llm_type else base.type + api_key = reader.str(Fragment.api_key) or base.api_key + api_base = reader.str(Fragment.api_base) or base.api_base + cognitive_services_endpoint = ( + reader.str(Fragment.cognitive_services_endpoint) + or 
base.cognitive_services_endpoint + ) + deployment_name = ( + reader.str(Fragment.deployment_name) or base.deployment_name + ) + + if api_key is None and not _is_azure(llm_type): + raise ApiKeyMissingError + if _is_azure(llm_type): + if api_base is None: + raise AzureApiBaseMissingError + if deployment_name is None: + raise AzureDeploymentNameMissingError + + sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation) + if sleep_on_rate_limit is None: + sleep_on_rate_limit = base.sleep_on_rate_limit_recommendation + + return LLMParameters( + api_key=api_key, + type=llm_type, + api_base=api_base, + api_version=reader.str(Fragment.api_version) or base.api_version, + organization=reader.str("organization") or base.organization, + proxy=reader.str("proxy") or base.proxy, + model=reader.str("model") or base.model, + max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens, + temperature=reader.float(Fragment.temperature) or base.temperature, + top_p=reader.float(Fragment.top_p) or base.top_p, + n=reader.int(Fragment.n) or base.n, + model_supports_json=reader.bool(Fragment.model_supports_json) + or base.model_supports_json, + request_timeout=reader.float(Fragment.request_timeout) + or base.request_timeout, + cognitive_services_endpoint=cognitive_services_endpoint, + deployment_name=deployment_name, + tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm) + or base.tokens_per_minute, + requests_per_minute=reader.int("requests_per_minute", Fragment.rpm) + or base.requests_per_minute, + max_retries=reader.int(Fragment.max_retries) or base.max_retries, + max_retry_wait=reader.float(Fragment.max_retry_wait) + or base.max_retry_wait, + sleep_on_rate_limit_recommendation=sleep_on_rate_limit, + concurrent_requests=reader.int(Fragment.concurrent_requests) + or base.concurrent_requests, + ) + + def hydrate_embeddings_params( + config: LLMConfigInput, base: LLMParameters + ) -> LLMParameters: + with reader.use(config.get("llm")): + api_type = reader.str(Fragment.type) or defs.EMBEDDING_TYPE + api_type = LLMType(api_type) if api_type else defs.LLM_TYPE + api_key = reader.str(Fragment.api_key) or base.api_key + + # In a unique events where: + # - same api_bases for LLM and embeddings (both Azure) + # - different api_bases for LLM and embeddings (both Azure) + # - LLM uses Azure OpenAI, while embeddings uses base OpenAI (this one is important) + # - LLM uses Azure OpenAI, while embeddings uses third-party OpenAI-like API + api_base = ( + reader.str(Fragment.api_base) or base.api_base + if _is_azure(api_type) + else reader.str(Fragment.api_base) + ) + api_version = ( + reader.str(Fragment.api_version) or base.api_version + if _is_azure(api_type) + else reader.str(Fragment.api_version) + ) + api_organization = reader.str("organization") or base.organization + api_proxy = reader.str("proxy") or base.proxy + cognitive_services_endpoint = ( + reader.str(Fragment.cognitive_services_endpoint) + or base.cognitive_services_endpoint + ) + deployment_name = reader.str(Fragment.deployment_name) + + if api_key is None and not _is_azure(api_type): + raise ApiKeyMissingError(embedding=True) + if _is_azure(api_type): + if api_base is None: + raise AzureApiBaseMissingError(embedding=True) + if deployment_name is None: + raise AzureDeploymentNameMissingError(embedding=True) + + sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation) + if sleep_on_rate_limit is None: + sleep_on_rate_limit = base.sleep_on_rate_limit_recommendation + + return LLMParameters( + api_key=api_key, + type=api_type, + 
api_base=api_base, + api_version=api_version, + organization=api_organization, + proxy=api_proxy, + model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL, + request_timeout=reader.float(Fragment.request_timeout) + or defs.LLM_REQUEST_TIMEOUT, + cognitive_services_endpoint=cognitive_services_endpoint, + deployment_name=deployment_name, + tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm) + or defs.LLM_TOKENS_PER_MINUTE, + requests_per_minute=reader.int("requests_per_minute", Fragment.rpm) + or defs.LLM_REQUESTS_PER_MINUTE, + max_retries=reader.int(Fragment.max_retries) or defs.LLM_MAX_RETRIES, + max_retry_wait=reader.float(Fragment.max_retry_wait) + or defs.LLM_MAX_RETRY_WAIT, + sleep_on_rate_limit_recommendation=sleep_on_rate_limit, + concurrent_requests=reader.int(Fragment.concurrent_requests) + or defs.LLM_CONCURRENT_REQUESTS, + ) + + def hydrate_parallelization_params( + config: LLMConfigInput, base: ParallelizationParameters + ) -> ParallelizationParameters: + with reader.use(config.get("parallelization")): + return ParallelizationParameters( + num_threads=reader.int("num_threads", Fragment.thread_count) + or base.num_threads, + stagger=reader.float("stagger", Fragment.thread_stagger) + or base.stagger, + ) + + fallback_oai_key = env("OPENAI_API_KEY", env("AZURE_OPENAI_API_KEY", None)) + fallback_oai_org = env("OPENAI_ORG_ID", None) + fallback_oai_base = env("OPENAI_BASE_URL", None) + fallback_oai_version = env("OPENAI_API_VERSION", None) + + with reader.envvar_prefix(Section.graphrag), reader.use(values): + async_mode = reader.str(Fragment.async_mode) + async_mode = AsyncType(async_mode) if async_mode else defs.ASYNC_MODE + + fallback_oai_key = reader.str(Fragment.api_key) or fallback_oai_key + fallback_oai_org = reader.str(Fragment.api_organization) or fallback_oai_org + fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base + fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version + fallback_oai_proxy = reader.str(Fragment.api_proxy) + + with reader.envvar_prefix(Section.llm): + with reader.use(values.get("llm")): + llm_type = reader.str(Fragment.type) + llm_type = LLMType(llm_type) if llm_type else defs.LLM_TYPE + api_key = reader.str(Fragment.api_key) or fallback_oai_key + api_organization = ( + reader.str(Fragment.api_organization) or fallback_oai_org + ) + api_base = reader.str(Fragment.api_base) or fallback_oai_base + api_version = reader.str(Fragment.api_version) or fallback_oai_version + api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy + cognitive_services_endpoint = reader.str( + Fragment.cognitive_services_endpoint + ) + deployment_name = reader.str(Fragment.deployment_name) + + if api_key is None and not _is_azure(llm_type): + raise ApiKeyMissingError + if _is_azure(llm_type): + if api_base is None: + raise AzureApiBaseMissingError + if deployment_name is None: + raise AzureDeploymentNameMissingError + + sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation) + if sleep_on_rate_limit is None: + sleep_on_rate_limit = defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION + + llm_model = LLMParameters( + api_key=api_key, + api_base=api_base, + api_version=api_version, + organization=api_organization, + proxy=api_proxy, + type=llm_type, + model=reader.str(Fragment.model) or defs.LLM_MODEL, + max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS, + temperature=reader.float(Fragment.temperature) + or defs.LLM_TEMPERATURE, + top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P, + 
n=reader.int(Fragment.n) or defs.LLM_N, + model_supports_json=reader.bool(Fragment.model_supports_json), + request_timeout=reader.float(Fragment.request_timeout) + or defs.LLM_REQUEST_TIMEOUT, + cognitive_services_endpoint=cognitive_services_endpoint, + deployment_name=deployment_name, + tokens_per_minute=reader.int(Fragment.tpm) + or defs.LLM_TOKENS_PER_MINUTE, + requests_per_minute=reader.int(Fragment.rpm) + or defs.LLM_REQUESTS_PER_MINUTE, + max_retries=reader.int(Fragment.max_retries) + or defs.LLM_MAX_RETRIES, + max_retry_wait=reader.float(Fragment.max_retry_wait) + or defs.LLM_MAX_RETRY_WAIT, + sleep_on_rate_limit_recommendation=sleep_on_rate_limit, + concurrent_requests=reader.int(Fragment.concurrent_requests) + or defs.LLM_CONCURRENT_REQUESTS, + ) + with reader.use(values.get("parallelization")): + llm_parallelization_model = ParallelizationParameters( + stagger=reader.float("stagger", Fragment.thread_stagger) + or defs.PARALLELIZATION_STAGGER, + num_threads=reader.int("num_threads", Fragment.thread_count) + or defs.PARALLELIZATION_NUM_THREADS, + ) + embeddings_config = values.get("embeddings") or {} + with reader.envvar_prefix(Section.embedding), reader.use(embeddings_config): + embeddings_target = reader.str("target") + embeddings_model = TextEmbeddingConfig( + llm=hydrate_embeddings_params(embeddings_config, llm_model), + parallelization=hydrate_parallelization_params( + embeddings_config, llm_parallelization_model + ), + vector_store=embeddings_config.get("vector_store", None), + async_mode=hydrate_async_type(embeddings_config, async_mode), + target=( + TextEmbeddingTarget(embeddings_target) + if embeddings_target + else defs.EMBEDDING_TARGET + ), + batch_size=reader.int("batch_size") or defs.EMBEDDING_BATCH_SIZE, + batch_max_tokens=reader.int("batch_max_tokens") + or defs.EMBEDDING_BATCH_MAX_TOKENS, + skip=reader.list("skip") or [], + ) + with ( + reader.envvar_prefix(Section.node2vec), + reader.use(values.get("embed_graph")), + ): + embed_graph_model = EmbedGraphConfig( + enabled=reader.bool(Fragment.enabled) or defs.NODE2VEC_ENABLED, + num_walks=reader.int("num_walks") or defs.NODE2VEC_NUM_WALKS, + walk_length=reader.int("walk_length") or defs.NODE2VEC_WALK_LENGTH, + window_size=reader.int("window_size") or defs.NODE2VEC_WINDOW_SIZE, + iterations=reader.int("iterations") or defs.NODE2VEC_ITERATIONS, + random_seed=reader.int("random_seed") or defs.NODE2VEC_RANDOM_SEED, + ) + with reader.envvar_prefix(Section.input), reader.use(values.get("input")): + input_type = reader.str("type") + file_type = reader.str(Fragment.file_type) + input_model = InputConfig( + file_type=( + InputFileType(file_type) if file_type else defs.INPUT_FILE_TYPE + ), + type=(InputType(input_type) if input_type else defs.INPUT_TYPE), + encoding=reader.str("file_encoding", Fragment.encoding) + or defs.INPUT_FILE_ENCODING, + base_dir=reader.str(Fragment.base_dir) or defs.INPUT_BASE_DIR, + file_pattern=reader.str("file_pattern") + or ( + defs.INPUT_TEXT_PATTERN + if file_type == InputFileType.text + else defs.INPUT_CSV_PATTERN + ), + source_column=reader.str("source_column"), + timestamp_column=reader.str("timestamp_column"), + timestamp_format=reader.str("timestamp_format"), + text_column=reader.str("text_column") or defs.INPUT_TEXT_COLUMN, + title_column=reader.str("title_column"), + document_attribute_columns=reader.list("document_attribute_columns") + or [], + connection_string=reader.str(Fragment.conn_string), + storage_account_blob_url=reader.str(Fragment.storage_account_blob_url), + 
container_name=reader.str(Fragment.container_name), + ) + with reader.envvar_prefix(Section.cache), reader.use(values.get("cache")): + c_type = reader.str(Fragment.type) + cache_model = CacheConfig( + type=CacheType(c_type) if c_type else defs.CACHE_TYPE, + connection_string=reader.str(Fragment.conn_string), + storage_account_blob_url=reader.str(Fragment.storage_account_blob_url), + container_name=reader.str(Fragment.container_name), + base_dir=reader.str(Fragment.base_dir) or defs.CACHE_BASE_DIR, + ) + with ( + reader.envvar_prefix(Section.reporting), + reader.use(values.get("reporting")), + ): + r_type = reader.str(Fragment.type) + reporting_model = ReportingConfig( + type=ReportingType(r_type) if r_type else defs.REPORTING_TYPE, + connection_string=reader.str(Fragment.conn_string), + storage_account_blob_url=reader.str(Fragment.storage_account_blob_url), + container_name=reader.str(Fragment.container_name), + base_dir=reader.str(Fragment.base_dir) or defs.REPORTING_BASE_DIR, + ) + with reader.envvar_prefix(Section.storage), reader.use(values.get("storage")): + s_type = reader.str(Fragment.type) + storage_model = StorageConfig( + type=StorageType(s_type) if s_type else defs.STORAGE_TYPE, + connection_string=reader.str(Fragment.conn_string), + storage_account_blob_url=reader.str(Fragment.storage_account_blob_url), + container_name=reader.str(Fragment.container_name), + base_dir=reader.str(Fragment.base_dir) or defs.STORAGE_BASE_DIR, + overwrite=reader.bool(Fragment.overwrite) or False + ) + with reader.envvar_prefix(Section.chunk), reader.use(values.get("chunks")): + group_by_columns = reader.list("group_by_columns", "BY_COLUMNS") + if group_by_columns is None: + group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS + + chunks_model = ChunkingConfig( + size=reader.int("size") or defs.CHUNK_SIZE, + overlap=reader.int("overlap") or defs.CHUNK_OVERLAP, + group_by_columns=group_by_columns, + encoding_model=reader.str(Fragment.encoding_model), + ) + with ( + reader.envvar_prefix(Section.snapshot), + reader.use(values.get("snapshots")), + ): + snapshots_model = SnapshotsConfig( + graphml=reader.bool("graphml") or defs.SNAPSHOTS_GRAPHML, + raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES, + top_level_nodes=reader.bool("top_level_nodes") + or defs.SNAPSHOTS_TOP_LEVEL_NODES, + ) + with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")): + umap_model = UmapConfig( + enabled=reader.bool(Fragment.enabled) or defs.UMAP_ENABLED, + ) + + entity_extraction_config = values.get("entity_extraction") or {} + with ( + reader.envvar_prefix(Section.entity_extraction), + reader.use(entity_extraction_config), + ): + max_gleanings = reader.int(Fragment.max_gleanings) + max_gleanings = ( + max_gleanings + if max_gleanings is not None + else defs.ENTITY_EXTRACTION_MAX_GLEANINGS + ) + + entity_extraction_model = EntityExtractionConfig( + llm=hydrate_llm_params(entity_extraction_config, llm_model), + parallelization=hydrate_parallelization_params( + entity_extraction_config, llm_parallelization_model + ), + async_mode=hydrate_async_type(entity_extraction_config, async_mode), + entity_types=reader.list("entity_types") + or defs.ENTITY_EXTRACTION_ENTITY_TYPES, + max_gleanings=max_gleanings, + prompt=reader.str("prompt", Fragment.prompt_file), + encoding_model=reader.str(Fragment.encoding_model), + ) + + claim_extraction_config = values.get("claim_extraction") or {} + with ( + reader.envvar_prefix(Section.claim_extraction), + reader.use(claim_extraction_config), + ): + max_gleanings = 
reader.int(Fragment.max_gleanings) + max_gleanings = ( + max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS + ) + claim_extraction_model = ClaimExtractionConfig( + enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED, + llm=hydrate_llm_params(claim_extraction_config, llm_model), + parallelization=hydrate_parallelization_params( + claim_extraction_config, llm_parallelization_model + ), + async_mode=hydrate_async_type(claim_extraction_config, async_mode), + description=reader.str("description") or defs.CLAIM_DESCRIPTION, + prompt=reader.str("prompt", Fragment.prompt_file), + max_gleanings=max_gleanings, + encoding_model=reader.str(Fragment.encoding_model), + ) + + community_report_config = values.get("community_reports") or {} + with ( + reader.envvar_prefix(Section.community_reports), + reader.use(community_report_config), + ): + community_reports_model = CommunityReportsConfig( + llm=hydrate_llm_params(community_report_config, llm_model), + parallelization=hydrate_parallelization_params( + community_report_config, llm_parallelization_model + ), + async_mode=hydrate_async_type(community_report_config, async_mode), + prompt=reader.str("prompt", Fragment.prompt_file), + max_length=reader.int(Fragment.max_length) + or defs.COMMUNITY_REPORT_MAX_LENGTH, + max_input_length=reader.int("max_input_length") + or defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH, + ) + + summarize_description_config = values.get("summarize_descriptions") or {} + with ( + reader.envvar_prefix(Section.summarize_descriptions), + reader.use(values.get("summarize_descriptions")), + ): + summarize_descriptions_model = SummarizeDescriptionsConfig( + llm=hydrate_llm_params(summarize_description_config, llm_model), + parallelization=hydrate_parallelization_params( + summarize_description_config, llm_parallelization_model + ), + async_mode=hydrate_async_type(summarize_description_config, async_mode), + prompt=reader.str("prompt", Fragment.prompt_file), + max_length=reader.int(Fragment.max_length) + or defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH, + ) + + with reader.use(values.get("cluster_graph")): + cluster_graph_model = ClusterGraphConfig( + max_cluster_size=reader.int("max_cluster_size") or defs.MAX_CLUSTER_SIZE + ) + + with ( + reader.use(values.get("local_search")), + reader.envvar_prefix(Section.local_search), + ): + local_search_model = LocalSearchConfig( + text_unit_prop=reader.float("text_unit_prop") + or defs.LOCAL_SEARCH_TEXT_UNIT_PROP, + community_prop=reader.float("community_prop") + or defs.LOCAL_SEARCH_COMMUNITY_PROP, + conversation_history_max_turns=reader.int( + "conversation_history_max_turns" + ) + or defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS, + top_k_entities=reader.int("top_k_entities") + or defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES, + top_k_relationships=reader.int("top_k_relationships") + or defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS, + temperature=reader.float("llm_temperature") + or defs.LOCAL_SEARCH_LLM_TEMPERATURE, + top_p=reader.float("llm_top_p") or defs.LOCAL_SEARCH_LLM_TOP_P, + n=reader.int("llm_n") or defs.LOCAL_SEARCH_LLM_N, + max_tokens=reader.int(Fragment.max_tokens) + or defs.LOCAL_SEARCH_MAX_TOKENS, + llm_max_tokens=reader.int("llm_max_tokens") + or defs.LOCAL_SEARCH_LLM_MAX_TOKENS, + ) + + with ( + reader.use(values.get("global_search")), + reader.envvar_prefix(Section.global_search), + ): + global_search_model = GlobalSearchConfig( + temperature=reader.float("llm_temperature") + or defs.GLOBAL_SEARCH_LLM_TEMPERATURE, + top_p=reader.float("llm_top_p") or 
defs.GLOBAL_SEARCH_LLM_TOP_P,
+            n=reader.int("llm_n") or defs.GLOBAL_SEARCH_LLM_N,
+            max_tokens=reader.int(Fragment.max_tokens)
+            or defs.GLOBAL_SEARCH_MAX_TOKENS,
+            data_max_tokens=reader.int("data_max_tokens")
+            or defs.GLOBAL_SEARCH_DATA_MAX_TOKENS,
+            map_max_tokens=reader.int("map_max_tokens")
+            or defs.GLOBAL_SEARCH_MAP_MAX_TOKENS,
+            reduce_max_tokens=reader.int("reduce_max_tokens")
+            or defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS,
+            concurrency=reader.int("concurrency") or defs.GLOBAL_SEARCH_CONCURRENCY,
+        )
+
+        with (
+            reader.use(values.get("query_context")),
+            reader.envvar_prefix(Section.query_context),
+        ):
+            query_context_model = QueryContextConfig(
+                files=reader.list("files") or [],
+            )
+
+        with (
+            reader.use(values.get("graphdb")),
+            reader.envvar_prefix(Section.graphdb),
+        ):
+            graphdb_model = GraphDBConfig(
+                account_name=reader.str("account_name") or None,
+                account_key=reader.str("account_key") or None,
+                username=reader.str("username") or None,
+                enabled=reader.bool("enabled") or False,
+                cosmos_url=reader.str("cosmos_url") or None,
+                gremlin_url=reader.str("gremlin_url") or None,
+            )
+
+        encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
+        skip_workflows = reader.list("skip_workflows") or []
+
+    return GraphRagConfig(
+        root_dir=root_dir,
+        llm=llm_model,
+        parallelization=llm_parallelization_model,
+        async_mode=async_mode,
+        embeddings=embeddings_model,
+        embed_graph=embed_graph_model,
+        reporting=reporting_model,
+        storage=storage_model,
+        cache=cache_model,
+        input=input_model,
+        chunks=chunks_model,
+        snapshots=snapshots_model,
+        entity_extraction=entity_extraction_model,
+        claim_extraction=claim_extraction_model,
+        community_reports=community_reports_model,
+        summarize_descriptions=summarize_descriptions_model,
+        umap=umap_model,
+        cluster_graph=cluster_graph_model,
+        encoding_model=encoding_model,
+        skip_workflows=skip_workflows,
+        local_search=local_search_model,
+        global_search=global_search_model,
+        query_context=query_context_model,
+        graphdb=graphdb_model,
+    )
+
+
+class Fragment(str, Enum):
+    """Configuration Fragments."""
+
+    api_base = "API_BASE"
+    api_key = "API_KEY"
+    api_version = "API_VERSION"
+    api_organization = "API_ORGANIZATION"
+    api_proxy = "API_PROXY"
+    async_mode = "ASYNC_MODE"
+    base_dir = "BASE_DIR"
+    overwrite = "OVERWRITE"
+    cognitive_services_endpoint = "COGNITIVE_SERVICES_ENDPOINT"
+    concurrent_requests = "CONCURRENT_REQUESTS"
+    conn_string = "CONNECTION_STRING"
+    container_name = "CONTAINER_NAME"
+    deployment_name = "DEPLOYMENT_NAME"
+    description = "DESCRIPTION"
+    enabled = "ENABLED"
+    encoding = "ENCODING"
+    encoding_model = "ENCODING_MODEL"
+    file_type = "FILE_TYPE"
+    max_gleanings = "MAX_GLEANINGS"
+    max_length = "MAX_LENGTH"
+    max_retries = "MAX_RETRIES"
+    max_retry_wait = "MAX_RETRY_WAIT"
+    max_tokens = "MAX_TOKENS"
+    temperature = "TEMPERATURE"
+    top_p = "TOP_P"
+    n = "N"
+    model = "MODEL"
+    model_supports_json = "MODEL_SUPPORTS_JSON"
+    prompt_file = "PROMPT_FILE"
+    request_timeout = "REQUEST_TIMEOUT"
+    rpm = "REQUESTS_PER_MINUTE"
+    sleep_recommendation = "SLEEP_ON_RATE_LIMIT_RECOMMENDATION"
+    storage_account_blob_url = "STORAGE_ACCOUNT_BLOB_URL"
+    thread_count = "THREAD_COUNT"
+    thread_stagger = "THREAD_STAGGER"
+    tpm = "TOKENS_PER_MINUTE"
+    type = "TYPE"
+    output = "OUTPUT"
+
+
+class Section(str, Enum):
+    """Configuration Sections."""
+
+    base = "BASE"
+    cache = "CACHE"
+    chunk = "CHUNK"
+    claim_extraction = "CLAIM_EXTRACTION"
+    community_reports = "COMMUNITY_REPORTS"
+    embedding = "EMBEDDING"
entity_extraction = "ENTITY_EXTRACTION" + graphrag = "GRAPHRAG" + input = "INPUT" + llm = "LLM" + node2vec = "NODE2VEC" + reporting = "REPORTING" + snapshot = "SNAPSHOT" + storage = "STORAGE" + summarize_descriptions = "SUMMARIZE_DESCRIPTIONS" + umap = "UMAP" + local_search = "LOCAL_SEARCH" + global_search = "GLOBAL_SEARCH" + query_context = "QUERY_CONTEXT" + graphdb = "GRAPHDB" + + +def _is_azure(llm_type: LLMType | None) -> bool: + return ( + llm_type == LLMType.AzureOpenAI + or llm_type == LLMType.AzureOpenAIChat + or llm_type == LLMType.AzureOpenAIEmbedding + ) + + +def _make_env(root_dir: str) -> Env: + read_dotenv(root_dir) + env = Env(expand_vars=True) + env.read_env() + return env + + +def _token_replace(data: dict): + """Replace env-var tokens in a dictionary object.""" + for key, value in data.items(): + if isinstance(value, dict): + _token_replace(value) + elif isinstance(value, str): + data[key] = os.path.expandvars(value) diff --git a/func-app/graphrag/config/defaults.py b/func-app/graphrag/config/defaults.py new file mode 100644 index 0000000000..4d6489140a --- /dev/null +++ b/func-app/graphrag/config/defaults.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Common default configuration values.""" + +from datashaper import AsyncType + +from .enums import ( + CacheType, + InputFileType, + InputType, + LLMType, + ReportingType, + StorageType, + TextEmbeddingTarget, +) + +ASYNC_MODE = AsyncType.Threaded +ENCODING_MODEL = "cl100k_base" +# +# LLM Parameters +# +LLM_TYPE = LLMType.OpenAIChat +LLM_MODEL = "gpt-4-turbo-preview" +LLM_MAX_TOKENS = 4000 +LLM_TEMPERATURE = 0 +LLM_TOP_P = 1 +LLM_N = 1 +LLM_REQUEST_TIMEOUT = 180.0 +LLM_TOKENS_PER_MINUTE = 0 +LLM_REQUESTS_PER_MINUTE = 0 +LLM_MAX_RETRIES = 10 +LLM_MAX_RETRY_WAIT = 10.0 +LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION = True +LLM_CONCURRENT_REQUESTS = 25 + +# +# Text Embedding Parameters +# +EMBEDDING_TYPE = LLMType.OpenAIEmbedding +EMBEDDING_MODEL = "text-embedding-3-small" +EMBEDDING_BATCH_SIZE = 16 +EMBEDDING_BATCH_MAX_TOKENS = 8191 +EMBEDDING_TARGET = TextEmbeddingTarget.required + +CACHE_TYPE = CacheType.file +CACHE_BASE_DIR = "cache" +CHUNK_SIZE = 1200 +CHUNK_OVERLAP = 100 +CHUNK_GROUP_BY_COLUMNS = ["id"] +CLAIM_DESCRIPTION = ( + "Any claims or facts that could be relevant to information discovery." 
+) +CLAIM_MAX_GLEANINGS = 1 +CLAIM_EXTRACTION_ENABLED = False +MAX_CLUSTER_SIZE = 10 +COMMUNITY_REPORT_MAX_LENGTH = 2000 +COMMUNITY_REPORT_MAX_INPUT_LENGTH = 8000 +ENTITY_EXTRACTION_ENTITY_TYPES = ["organization", "person", "geo", "event"] +ENTITY_EXTRACTION_MAX_GLEANINGS = 1 +INPUT_FILE_TYPE = InputFileType.text +INPUT_TYPE = InputType.file +INPUT_BASE_DIR = "input" +INPUT_FILE_ENCODING = "utf-8" +INPUT_TEXT_COLUMN = "text" +INPUT_CSV_PATTERN = ".*\\.csv$" +INPUT_TEXT_PATTERN = ".*\\.txt$" +PARALLELIZATION_STAGGER = 0.3 +PARALLELIZATION_NUM_THREADS = 50 +NODE2VEC_ENABLED = False +NODE2VEC_NUM_WALKS = 10 +NODE2VEC_WALK_LENGTH = 40 +NODE2VEC_WINDOW_SIZE = 2 +NODE2VEC_ITERATIONS = 3 +NODE2VEC_RANDOM_SEED = 597832 +REPORTING_TYPE = ReportingType.file +REPORTING_BASE_DIR = "output/${timestamp}/reports" +SNAPSHOTS_GRAPHML = False +SNAPSHOTS_RAW_ENTITIES = False +SNAPSHOTS_TOP_LEVEL_NODES = False +STORAGE_BASE_DIR = "output/${timestamp}/artifacts" +STORAGE_TYPE = StorageType.file +SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500 +UMAP_ENABLED = False + +# Local Search +LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5 +LOCAL_SEARCH_COMMUNITY_PROP = 0.1 +LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS = 5 +LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES = 10 +LOCAL_SEARCH_TOP_K_RELATIONSHIPS = 10 +LOCAL_SEARCH_MAX_TOKENS = 12_000 +LOCAL_SEARCH_LLM_TEMPERATURE = 0 +LOCAL_SEARCH_LLM_TOP_P = 1 +LOCAL_SEARCH_LLM_N = 1 +LOCAL_SEARCH_LLM_MAX_TOKENS = 2000 + +# Global Search +GLOBAL_SEARCH_LLM_TEMPERATURE = 0 +GLOBAL_SEARCH_LLM_TOP_P = 1 +GLOBAL_SEARCH_LLM_N = 1 +GLOBAL_SEARCH_MAX_TOKENS = 12_000 +GLOBAL_SEARCH_DATA_MAX_TOKENS = 12_000 +GLOBAL_SEARCH_MAP_MAX_TOKENS = 1000 +GLOBAL_SEARCH_REDUCE_MAX_TOKENS = 2_000 +GLOBAL_SEARCH_CONCURRENCY = 32 diff --git a/func-app/graphrag/config/enums.py b/func-app/graphrag/config/enums.py new file mode 100644 index 0000000000..4745acc5f5 --- /dev/null +++ b/func-app/graphrag/config/enums.py @@ -0,0 +1,127 @@ +# Copyright (c) 2024 Microsoft Corporation. 
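A hedged sketch of how the defaults above combine with explicit overrides through create_graphrag_config; the placeholder API key and the override values are assumptions for illustration only:

import os

from graphrag.config import create_graphrag_config

os.environ["GRAPHRAG_API_KEY"] = "sk-placeholder"  # any non-empty value satisfies the key check

config = create_graphrag_config(
    values={
        "llm": {"max_tokens": 2000},                           # overrides LLM_MAX_TOKENS (4000)
        "storage": {"type": "file", "base_dir": "out/artifacts"},  # overrides STORAGE_BASE_DIR
    },
    root_dir=".",
)
assert config.llm.max_tokens == 2000
assert config.chunks.size == 1200                  # untouched sections fall back to CHUNK_SIZE
assert config.llm.model == "gpt-4-turbo-preview"   # LLM_MODEL default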
+# Licensed under the MIT License + +"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig' and 'PipelineMemoryCacheConfig' models.""" + +from __future__ import annotations + +from enum import Enum + + +class CacheType(str, Enum): + """The cache configuration type for the pipeline.""" + + file = "file" + """The file cache configuration type.""" + memory = "memory" + """The memory cache configuration type.""" + none = "none" + """The none cache configuration type.""" + blob = "blob" + """The blob cache configuration type.""" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +class InputFileType(str, Enum): + """The input file type for the pipeline.""" + + csv = "csv" + """The CSV input type.""" + text = "text" + """The text input type.""" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +class InputType(str, Enum): + """The input type for the pipeline.""" + + file = "file" + """The file storage type.""" + blob = "blob" + """The blob storage type.""" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +class StorageType(str, Enum): + """The storage type for the pipeline.""" + + file = "file" + """The file storage type.""" + memory = "memory" + """The memory storage type.""" + blob = "blob" + """The blob storage type.""" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +class ReportingType(str, Enum): + """The reporting configuration type for the pipeline.""" + + file = "file" + """The file reporting configuration type.""" + console = "console" + """The console reporting configuration type.""" + blob = "blob" + """The blob reporting configuration type.""" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +class TextEmbeddingTarget(str, Enum): + """The target to use for text embeddings.""" + + all = "all" + required = "required" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +class LLMType(str, Enum): + """LLMType enum class definition.""" + + # Embeddings + OpenAIEmbedding = "openai_embedding" + AzureOpenAIEmbedding = "azure_openai_embedding" + + # Raw Completion + OpenAI = "openai" + AzureOpenAI = "azure_openai" + + # Chat Completion + OpenAIChat = "openai_chat" + AzureOpenAIChat = "azure_openai_chat" + + # Debug + StaticResponse = "static_response" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + +class ContextSwitchType(str, Enum): + """context switcher type.""" + + #context switch types + Activate = "activate" + Deactivate= "deactivate" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + diff --git a/func-app/graphrag/config/environment_reader.py b/func-app/graphrag/config/environment_reader.py new file mode 100644 index 0000000000..258422666c --- /dev/null +++ b/func-app/graphrag/config/environment_reader.py @@ -0,0 +1,155 @@ +# Copyright (c) 2024 Microsoft Corporation. 
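A quick sketch of the enum semantics defined above: each enum subclasses str, so members compare equal to raw configuration strings and round-trip from user input. The import path assumes the func-app package layout:

from graphrag.config.enums import ContextSwitchType, StorageType

assert StorageType("blob") is StorageType.blob            # construct from raw user input
assert StorageType.blob == "blob"                         # str subclass: compares to plain str
assert repr(ContextSwitchType.Activate) == '"activate"'   # custom __repr__ quotes the value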
+# Licensed under the MIT License + +"""A configuration reader utility class.""" + +from collections.abc import Callable +from contextlib import contextmanager +from enum import Enum +from typing import Any, TypeVar + +from environs import Env + +T = TypeVar("T") + +KeyValue = str | Enum +EnvKeySet = str | list[str] + + +def read_key(value: KeyValue) -> str: + """Read a key value.""" + if not isinstance(value, str): + return value.value.lower() + return value.lower() + + +class EnvironmentReader: + """A configuration reader utility class.""" + + _env: Env + _config_stack: list[dict] + + def __init__(self, env: Env): + self._env = env + self._config_stack = [] + + @property + def env(self): + """Get the environment object.""" + return self._env + + def _read_env( + self, env_key: str | list[str], default_value: T, read: Callable[[str, T], T] + ) -> T | None: + if isinstance(env_key, str): + env_key = [env_key] + + for k in env_key: + result = read(k.upper(), default_value) + if result is not default_value: + return result + + return default_value + + def envvar_prefix(self, prefix: KeyValue): + """Set the environment variable prefix.""" + prefix = read_key(prefix) + prefix = f"{prefix}_".upper() + return self._env.prefixed(prefix) + + def use(self, value: Any | None): + """Create a context manager to push the value into the config_stack.""" + + @contextmanager + def config_context(): + self._config_stack.append(value or {}) + try: + yield + finally: + self._config_stack.pop() + + return config_context() + + @property + def section(self) -> dict: + """Get the current section.""" + return self._config_stack[-1] if self._config_stack else {} + + def str( + self, + key: KeyValue, + env_key: EnvKeySet | None = None, + default_value: str | None = None, + ) -> str | None: + """Read a configuration value.""" + key = read_key(key) + if self.section and key in self.section: + return self.section[key] + + return self._read_env( + env_key or key, default_value, (lambda k, dv: self._env(k, dv)) + ) + + def int( + self, + key: KeyValue, + env_key: EnvKeySet | None = None, + default_value: int | None = None, + ) -> int | None: + """Read an integer configuration value.""" + key = read_key(key) + if self.section and key in self.section: + return int(self.section[key]) + return self._read_env( + env_key or key, default_value, lambda k, dv: self._env.int(k, dv) + ) + + def bool( + self, + key: KeyValue, + env_key: EnvKeySet | None = None, + default_value: bool | None = None, + ) -> bool | None: + """Read an integer configuration value.""" + key = read_key(key) + if self.section and key in self.section: + return bool(self.section[key]) + + return self._read_env( + env_key or key, default_value, lambda k, dv: self._env.bool(k, dv) + ) + + def float( + self, + key: KeyValue, + env_key: EnvKeySet | None = None, + default_value: float | None = None, + ) -> float | None: + """Read a float configuration value.""" + key = read_key(key) + if self.section and key in self.section: + return float(self.section[key]) + return self._read_env( + env_key or key, default_value, lambda k, dv: self._env.float(k, dv) + ) + + def list( + self, + key: KeyValue, + env_key: EnvKeySet | None = None, + default_value: list | None = None, + ) -> list | None: + """Parse an list configuration value.""" + key = read_key(key) + result = None + if self.section and key in self.section: + result = self.section[key] + if isinstance(result, list): + return result + + if result is None: + result = self.str(key, env_key) + if result is not None: + 
result = [s.strip() for s in result.split(",")] + return [s for s in result if s] + return default_value diff --git a/func-app/graphrag/config/errors.py b/func-app/graphrag/config/errors.py new file mode 100644 index 0000000000..9a2161b8af --- /dev/null +++ b/func-app/graphrag/config/errors.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License +"""Errors for the default configuration.""" + + +class ApiKeyMissingError(ValueError): + """LLM Key missing error.""" + + def __init__(self, embedding: bool = False) -> None: + """Init method definition.""" + api_type = "Embedding" if embedding else "Completion" + api_key = "GRAPHRAG_EMBEDDING_API_KEY" if embedding else "GRAPHRAG_LLM_API_KEY" + msg = f"API Key is required for {api_type} API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or {api_key} environment variable." + super().__init__(msg) + + +class AzureApiBaseMissingError(ValueError): + """Azure API Base missing error.""" + + def __init__(self, embedding: bool = False) -> None: + """Init method definition.""" + api_type = "Embedding" if embedding else "Completion" + api_base = "GRAPHRAG_EMBEDDING_API_BASE" if embedding else "GRAPHRAG_API_BASE" + msg = f"API Base is required for {api_type} API. Please set either the OPENAI_API_BASE, GRAPHRAG_API_BASE or {api_base} environment variable." + super().__init__(msg) + + +class AzureDeploymentNameMissingError(ValueError): + """Azure Deployment Name missing error.""" + + def __init__(self, embedding: bool = False) -> None: + """Init method definition.""" + api_type = "Embedding" if embedding else "Completion" + api_base = ( + "GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME" + if embedding + else "GRAPHRAG_LLM_DEPLOYMENT_NAME" + ) + msg = f"Deployment Name is required for {api_type} API. Please set either the OPENAI_DEPLOYMENT_NAME, GRAPHRAG_LLM_DEPLOYMENT_NAME or {api_base} environment variable." + super().__init__(msg) diff --git a/func-app/graphrag/config/input_models/__init__.py b/func-app/graphrag/config/input_models/__init__.py new file mode 100644 index 0000000000..f905ae38b2 --- /dev/null +++ b/func-app/graphrag/config/input_models/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 Microsoft Corporation. 
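A minimal sketch of the EnvironmentReader resolution order implemented above: a pushed section dict wins, then a prefixed environment variable, then the supplied default. The variable name and values are illustrative assumptions:

import os

from environs import Env

from graphrag.config.environment_reader import EnvironmentReader

os.environ["GRAPHRAG_MAX_RETRIES"] = "7"  # illustrative variable

reader = EnvironmentReader(Env(expand_vars=True))
with reader.envvar_prefix("graphrag"), reader.use({"model": "gpt-4o"}):
    assert reader.str("model") == "gpt-4o"   # 1) the pushed section dict wins
    assert reader.int("max_retries") == 7    # 2) else the prefixed env var (GRAPHRAG_MAX_RETRIES)
    assert reader.float("temperature", default_value=0.9) == 0.9  # 3) else the default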
+# Licensed under the MIT License + +"""Interfaces for Default Config parameterization.""" + +from .cache_config_input import CacheConfigInput +from .chunking_config_input import ChunkingConfigInput +from .claim_extraction_config_input import ClaimExtractionConfigInput +from .cluster_graph_config_input import ClusterGraphConfigInput +from .community_reports_config_input import CommunityReportsConfigInput +from .embed_graph_config_input import EmbedGraphConfigInput +from .entity_extraction_config_input import EntityExtractionConfigInput +from .global_search_config_input import GlobalSearchConfigInput +from .graphrag_config_input import GraphRagConfigInput +from .input_config_input import InputConfigInput +from .llm_config_input import LLMConfigInput +from .llm_parameters_input import LLMParametersInput +from .local_search_config_input import LocalSearchConfigInput +from .parallelization_parameters_input import ParallelizationParametersInput +from .reporting_config_input import ReportingConfigInput +from .snapshots_config_input import SnapshotsConfigInput +from .storage_config_input import StorageConfigInput +from .summarize_descriptions_config_input import ( + SummarizeDescriptionsConfigInput, +) +from .text_embedding_config_input import TextEmbeddingConfigInput +from .umap_config_input import UmapConfigInput + +__all__ = [ + "CacheConfigInput", + "ChunkingConfigInput", + "ClaimExtractionConfigInput", + "ClusterGraphConfigInput", + "CommunityReportsConfigInput", + "EmbedGraphConfigInput", + "EntityExtractionConfigInput", + "GlobalSearchConfigInput", + "GraphRagConfigInput", + "InputConfigInput", + "LLMConfigInput", + "LLMParametersInput", + "LocalSearchConfigInput", + "ParallelizationParametersInput", + "ReportingConfigInput", + "SnapshotsConfigInput", + "StorageConfigInput", + "SummarizeDescriptionsConfigInput", + "TextEmbeddingConfigInput", + "UmapConfigInput", +] diff --git a/func-app/graphrag/config/input_models/cache_config_input.py b/func-app/graphrag/config/input_models/cache_config_input.py new file mode 100644 index 0000000000..fe88d35b44 --- /dev/null +++ b/func-app/graphrag/config/input_models/cache_config_input.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + +from graphrag.config.enums import CacheType + + +class CacheConfigInput(TypedDict): + """The default configuration section for Cache.""" + + type: NotRequired[CacheType | str | None] + base_dir: NotRequired[str | None] + connection_string: NotRequired[str | None] + container_name: NotRequired[str | None] + storage_account_blob_url: NotRequired[str | None] diff --git a/func-app/graphrag/config/input_models/chunking_config_input.py b/func-app/graphrag/config/input_models/chunking_config_input.py new file mode 100644 index 0000000000..bbf4fc735f --- /dev/null +++ b/func-app/graphrag/config/input_models/chunking_config_input.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. 
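A small sketch of why these input models are TypedDicts with every field NotRequired: a sparse dict is a valid, type-checkable input, and string values are allowed because they may arrive via environment-variable expansion before coercion. The values are illustrative:

from graphrag.config.input_models import CacheConfigInput

cache_input: CacheConfigInput = {"type": "blob", "container_name": "my-cache"}  # partial is fine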
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + + +class ChunkingConfigInput(TypedDict): + """Configuration section for chunking.""" + + size: NotRequired[int | str | None] + overlap: NotRequired[int | str | None] + group_by_columns: NotRequired[list[str] | str | None] + strategy: NotRequired[dict | None] diff --git a/func-app/graphrag/config/input_models/claim_extraction_config_input.py b/func-app/graphrag/config/input_models/claim_extraction_config_input.py new file mode 100644 index 0000000000..f23e31d0a7 --- /dev/null +++ b/func-app/graphrag/config/input_models/claim_extraction_config_input.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired + +from .llm_config_input import LLMConfigInput + + +class ClaimExtractionConfigInput(LLMConfigInput): + """Configuration section for claim extraction.""" + + enabled: NotRequired[bool | None] + prompt: NotRequired[str | None] + description: NotRequired[str | None] + max_gleanings: NotRequired[int | str | None] + strategy: NotRequired[dict | None] + encoding_model: NotRequired[str | None] diff --git a/func-app/graphrag/config/input_models/cluster_graph_config_input.py b/func-app/graphrag/config/input_models/cluster_graph_config_input.py new file mode 100644 index 0000000000..eb6f9cd1c6 --- /dev/null +++ b/func-app/graphrag/config/input_models/cluster_graph_config_input.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + + +class ClusterGraphConfigInput(TypedDict): + """Configuration section for clustering graphs.""" + + max_cluster_size: NotRequired[int | None] + strategy: NotRequired[dict | None] diff --git a/func-app/graphrag/config/input_models/community_reports_config_input.py b/func-app/graphrag/config/input_models/community_reports_config_input.py new file mode 100644 index 0000000000..79ae3152e7 --- /dev/null +++ b/func-app/graphrag/config/input_models/community_reports_config_input.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired + +from .llm_config_input import LLMConfigInput + + +class CommunityReportsConfigInput(LLMConfigInput): + """Configuration section for community reports.""" + + prompt: NotRequired[str | None] + max_length: NotRequired[int | str | None] + max_input_length: NotRequired[int | str | None] + strategy: NotRequired[dict | None] diff --git a/func-app/graphrag/config/input_models/embed_graph_config_input.py b/func-app/graphrag/config/input_models/embed_graph_config_input.py new file mode 100644 index 0000000000..f8b6ee6faf --- /dev/null +++ b/func-app/graphrag/config/input_models/embed_graph_config_input.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. 
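A hedged sketch of per-section LLM overrides: inputs that extend LLMConfigInput (claim extraction, community reports, and so on) can carry their own "llm" block, which hydrate_llm_params merges over the root model. The placeholder key and model names are assumptions:

import os

from graphrag.config import create_graphrag_config

os.environ["GRAPHRAG_API_KEY"] = "sk-placeholder"  # placeholder; real runs need a valid key

config = create_graphrag_config(
    values={
        "llm": {"model": "gpt-4-turbo-preview"},
        "claim_extraction": {
            "enabled": True,
            "llm": {"model": "gpt-4o-mini"},  # hypothetical per-step override
        },
    },
    root_dir=".",
)
assert config.claim_extraction.llm.model == "gpt-4o-mini"  # section override wins
assert config.llm.model == "gpt-4-turbo-preview"           # root model unchanged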
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class EmbedGraphConfigInput(TypedDict):
+    """The default configuration section for Node2Vec."""
+
+    enabled: NotRequired[bool | str | None]
+    num_walks: NotRequired[int | str | None]
+    walk_length: NotRequired[int | str | None]
+    window_size: NotRequired[int | str | None]
+    iterations: NotRequired[int | str | None]
+    random_seed: NotRequired[int | str | None]
+    strategy: NotRequired[dict | None]
diff --git a/func-app/graphrag/config/input_models/entity_extraction_config_input.py b/func-app/graphrag/config/input_models/entity_extraction_config_input.py
new file mode 100644
index 0000000000..f1d3587e99
--- /dev/null
+++ b/func-app/graphrag/config/input_models/entity_extraction_config_input.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired
+
+from .llm_config_input import LLMConfigInput
+
+
+class EntityExtractionConfigInput(LLMConfigInput):
+    """Configuration section for entity extraction."""
+
+    prompt: NotRequired[str | None]
+    entity_types: NotRequired[list[str] | str | None]
+    max_gleanings: NotRequired[int | str | None]
+    strategy: NotRequired[dict | None]
+    encoding_model: NotRequired[str | None]
diff --git a/func-app/graphrag/config/input_models/global_search_config_input.py b/func-app/graphrag/config/input_models/global_search_config_input.py
new file mode 100644
index 0000000000..e13fbbfa9e
--- /dev/null
+++ b/func-app/graphrag/config/input_models/global_search_config_input.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class GlobalSearchConfigInput(TypedDict):
+    """The default configuration section for Global Search."""
+
+    max_tokens: NotRequired[int | str | None]
+    data_max_tokens: NotRequired[int | str | None]
+    map_max_tokens: NotRequired[int | str | None]
+    reduce_max_tokens: NotRequired[int | str | None]
+    concurrency: NotRequired[int | str | None]
diff --git a/func-app/graphrag/config/input_models/graphrag_config_input.py b/func-app/graphrag/config/input_models/graphrag_config_input.py
new file mode 100644
index 0000000000..7c04dea2e3
--- /dev/null
+++ b/func-app/graphrag/config/input_models/graphrag_config_input.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired + +from .cache_config_input import CacheConfigInput +from .chunking_config_input import ChunkingConfigInput +from .claim_extraction_config_input import ClaimExtractionConfigInput +from .cluster_graph_config_input import ClusterGraphConfigInput +from .community_reports_config_input import CommunityReportsConfigInput +from .embed_graph_config_input import EmbedGraphConfigInput +from .entity_extraction_config_input import EntityExtractionConfigInput +from .global_search_config_input import GlobalSearchConfigInput +from .input_config_input import InputConfigInput +from .llm_config_input import LLMConfigInput +from .local_search_config_input import LocalSearchConfigInput +from .reporting_config_input import ReportingConfigInput +from .snapshots_config_input import SnapshotsConfigInput +from .storage_config_input import StorageConfigInput +from .summarize_descriptions_config_input import ( + SummarizeDescriptionsConfigInput, +) +from .text_embedding_config_input import TextEmbeddingConfigInput +from .umap_config_input import UmapConfigInput + + +class GraphRagConfigInput(LLMConfigInput): + """Base class for the Default-Configuration parameterization settings.""" + + reporting: NotRequired[ReportingConfigInput | None] + storage: NotRequired[StorageConfigInput | None] + cache: NotRequired[CacheConfigInput | None] + input: NotRequired[InputConfigInput | None] + embed_graph: NotRequired[EmbedGraphConfigInput | None] + embeddings: NotRequired[TextEmbeddingConfigInput | None] + chunks: NotRequired[ChunkingConfigInput | None] + snapshots: NotRequired[SnapshotsConfigInput | None] + entity_extraction: NotRequired[EntityExtractionConfigInput | None] + summarize_descriptions: NotRequired[SummarizeDescriptionsConfigInput | None] + community_reports: NotRequired[CommunityReportsConfigInput | None] + claim_extraction: NotRequired[ClaimExtractionConfigInput | None] + cluster_graph: NotRequired[ClusterGraphConfigInput | None] + umap: NotRequired[UmapConfigInput | None] + encoding_model: NotRequired[str | None] + skip_workflows: NotRequired[list[str] | str | None] + local_search: NotRequired[LocalSearchConfigInput | None] + global_search: NotRequired[GlobalSearchConfigInput | None] diff --git a/func-app/graphrag/config/input_models/input_config_input.py b/func-app/graphrag/config/input_models/input_config_input.py new file mode 100644 index 0000000000..4ff89d2c9a --- /dev/null +++ b/func-app/graphrag/config/input_models/input_config_input.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + +from graphrag.config.enums import InputFileType, InputType + + +class InputConfigInput(TypedDict): + """The default configuration section for Input.""" + + type: NotRequired[InputType | str | None] + file_type: NotRequired[InputFileType | str | None] + base_dir: NotRequired[str | None] + connection_string: NotRequired[str | None] + container_name: NotRequired[str | None] + file_encoding: NotRequired[str | None] + file_pattern: NotRequired[str | None] + source_column: NotRequired[str | None] + timestamp_column: NotRequired[str | None] + timestamp_format: NotRequired[str | None] + text_column: NotRequired[str | None] + title_column: NotRequired[str | None] + document_attribute_columns: NotRequired[list[str] | str | None] + storage_account_blob_url: NotRequired[str | None] diff --git a/func-app/graphrag/config/input_models/llm_config_input.py b/func-app/graphrag/config/input_models/llm_config_input.py new file mode 100644 index 0000000000..67231371b8 --- /dev/null +++ b/func-app/graphrag/config/input_models/llm_config_input.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from datashaper import AsyncType +from typing_extensions import NotRequired, TypedDict + +from .llm_parameters_input import LLMParametersInput +from .parallelization_parameters_input import ParallelizationParametersInput + + +class LLMConfigInput(TypedDict): + """Base class for LLM-configured steps.""" + + llm: NotRequired[LLMParametersInput | None] + parallelization: NotRequired[ParallelizationParametersInput | None] + async_mode: NotRequired[AsyncType | str | None] diff --git a/func-app/graphrag/config/input_models/llm_parameters_input.py b/func-app/graphrag/config/input_models/llm_parameters_input.py new file mode 100644 index 0000000000..c89c6c0922 --- /dev/null +++ b/func-app/graphrag/config/input_models/llm_parameters_input.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""LLM Parameters model."""
+
+from typing_extensions import NotRequired, TypedDict
+
+from graphrag.config.enums import LLMType
+
+
+class LLMParametersInput(TypedDict):
+    """LLM Parameters model."""
+
+    api_key: NotRequired[str | None]
+    type: NotRequired[LLMType | str | None]
+    model: NotRequired[str | None]
+    max_tokens: NotRequired[int | str | None]
+    request_timeout: NotRequired[float | str | None]
+    api_base: NotRequired[str | None]
+    api_version: NotRequired[str | None]
+    organization: NotRequired[str | None]
+    proxy: NotRequired[str | None]
+    cognitive_services_endpoint: NotRequired[str | None]
+    deployment_name: NotRequired[str | None]
+    model_supports_json: NotRequired[bool | str | None]
+    tokens_per_minute: NotRequired[int | str | None]
+    requests_per_minute: NotRequired[int | str | None]
+    max_retries: NotRequired[int | str | None]
+    max_retry_wait: NotRequired[float | str | None]
+    sleep_on_rate_limit_recommendation: NotRequired[bool | str | None]
+    concurrent_requests: NotRequired[int | str | None]
diff --git a/func-app/graphrag/config/input_models/local_search_config_input.py b/func-app/graphrag/config/input_models/local_search_config_input.py
new file mode 100644
index 0000000000..23df40102a
--- /dev/null
+++ b/func-app/graphrag/config/input_models/local_search_config_input.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class LocalSearchConfigInput(TypedDict):
+    """The default configuration section for Local Search."""
+
+    text_unit_prop: NotRequired[float | str | None]
+    community_prop: NotRequired[float | str | None]
+    conversation_history_max_turns: NotRequired[int | str | None]
+    top_k_entities: NotRequired[int | str | None]
+    top_k_relationships: NotRequired[int | str | None]
+    max_tokens: NotRequired[int | str | None]
+    llm_max_tokens: NotRequired[int | str | None]
diff --git a/func-app/graphrag/config/input_models/parallelization_parameters_input.py b/func-app/graphrag/config/input_models/parallelization_parameters_input.py
new file mode 100644
index 0000000000..e9204437b2
--- /dev/null
+++ b/func-app/graphrag/config/input_models/parallelization_parameters_input.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parallelization parameters model."""
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class ParallelizationParametersInput(TypedDict):
+    """Parallelization parameters model."""
+
+    stagger: NotRequired[float | str | None]
+    num_threads: NotRequired[int | str | None]
diff --git a/func-app/graphrag/config/input_models/query_context_config_input.py b/func-app/graphrag/config/input_models/query_context_config_input.py
new file mode 100644
index 0000000000..c8f3d2e783
--- /dev/null
+++ b/func-app/graphrag/config/input_models/query_context_config_input.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class QueryContextConfigInput(TypedDict):
+    """The default configuration section for Query Context."""
+
+    files: NotRequired[list[str] | str | None]
+    """The list of files on which the query should be run."""
diff --git a/func-app/graphrag/config/input_models/reporting_config_input.py b/func-app/graphrag/config/input_models/reporting_config_input.py
new file mode 100644
index 0000000000..a224f0b440
--- /dev/null
+++ b/func-app/graphrag/config/input_models/reporting_config_input.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + +from graphrag.config.enums import ReportingType + + +class ReportingConfigInput(TypedDict): + """The default configuration section for Reporting.""" + + type: NotRequired[ReportingType | str | None] + base_dir: NotRequired[str | None] + connection_string: NotRequired[str | None] + container_name: NotRequired[str | None] + storage_account_blob_url: NotRequired[str | None] diff --git a/func-app/graphrag/config/input_models/snapshots_config_input.py b/func-app/graphrag/config/input_models/snapshots_config_input.py new file mode 100644 index 0000000000..c20becb071 --- /dev/null +++ b/func-app/graphrag/config/input_models/snapshots_config_input.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + + +class SnapshotsConfigInput(TypedDict): + """Configuration section for snapshots.""" + + graphml: NotRequired[bool | str | None] + raw_entities: NotRequired[bool | str | None] + top_level_nodes: NotRequired[bool | str | None] diff --git a/func-app/graphrag/config/input_models/storage_config_input.py b/func-app/graphrag/config/input_models/storage_config_input.py new file mode 100644 index 0000000000..cc5caf7952 --- /dev/null +++ b/func-app/graphrag/config/input_models/storage_config_input.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + +from graphrag.config.enums import StorageType + + +class StorageConfigInput(TypedDict): + """The default configuration section for Storage.""" + + type: NotRequired[StorageType | str | None] + base_dir: NotRequired[str | None] + connection_string: NotRequired[str | None] + container_name: NotRequired[str | None] + storage_account_blob_url: NotRequired[str | None] diff --git a/func-app/graphrag/config/input_models/summarize_descriptions_config_input.py b/func-app/graphrag/config/input_models/summarize_descriptions_config_input.py new file mode 100644 index 0000000000..6ce756e558 --- /dev/null +++ b/func-app/graphrag/config/input_models/summarize_descriptions_config_input.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired + +from .llm_config_input import LLMConfigInput + + +class SummarizeDescriptionsConfigInput(LLMConfigInput): + """Configuration section for description summarization.""" + + prompt: NotRequired[str | None] + max_length: NotRequired[int | str | None] + strategy: NotRequired[dict | None] diff --git a/func-app/graphrag/config/input_models/text_embedding_config_input.py b/func-app/graphrag/config/input_models/text_embedding_config_input.py new file mode 100644 index 0000000000..a7e176c658 --- /dev/null +++ b/func-app/graphrag/config/input_models/text_embedding_config_input.py @@ -0,0 +1,23 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired
+
+from graphrag.config.enums import (
+    TextEmbeddingTarget,
+)
+
+from .llm_config_input import LLMConfigInput
+
+
+class TextEmbeddingConfigInput(LLMConfigInput):
+    """Configuration section for text embeddings."""
+
+    batch_size: NotRequired[int | str | None]
+    batch_max_tokens: NotRequired[int | str | None]
+    target: NotRequired[TextEmbeddingTarget | str | None]
+    skip: NotRequired[list[str] | str | None]
+    vector_store: NotRequired[dict | None]
+    strategy: NotRequired[dict | None]
diff --git a/func-app/graphrag/config/input_models/umap_config_input.py b/func-app/graphrag/config/input_models/umap_config_input.py
new file mode 100644
index 0000000000..543ca385e0
--- /dev/null
+++ b/func-app/graphrag/config/input_models/umap_config_input.py
@@ -0,0 +1,12 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class UmapConfigInput(TypedDict):
+    """Configuration section for UMAP."""
+
+    enabled: NotRequired[bool | str | None]
diff --git a/func-app/graphrag/config/models/__init__.py b/func-app/graphrag/config/models/__init__.py
new file mode 100644
index 0000000000..f1d206ef85
--- /dev/null
+++ b/func-app/graphrag/config/models/__init__.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Interfaces for Default Config parameterization."""
+
+from .cache_config import CacheConfig
+from .chunking_config import ChunkingConfig
+from .claim_extraction_config import ClaimExtractionConfig
+from .cluster_graph_config import ClusterGraphConfig
+from .community_reports_config import CommunityReportsConfig
+from .embed_graph_config import EmbedGraphConfig
+from .entity_extraction_config import EntityExtractionConfig
+from .global_search_config import GlobalSearchConfig
+from .graph_rag_config import GraphRagConfig
+from .graphdb_config import GraphDBConfig
+from .input_config import InputConfig
+from .llm_config import LLMConfig
+from .llm_parameters import LLMParameters
+from .local_search_config import LocalSearchConfig
+from .parallelization_parameters import ParallelizationParameters
+from .query_context_config import QueryContextConfig
+from .reporting_config import ReportingConfig
+from .snapshots_config import SnapshotsConfig
+from .storage_config import StorageConfig
+from .summarize_descriptions_config import SummarizeDescriptionsConfig
+from .text_embedding_config import TextEmbeddingConfig
+from .umap_config import UmapConfig
+
+__all__ = [
+    "CacheConfig",
+    "ChunkingConfig",
+    "ClaimExtractionConfig",
+    "ClusterGraphConfig",
+    "CommunityReportsConfig",
+    "EmbedGraphConfig",
+    "EntityExtractionConfig",
+    "GlobalSearchConfig",
+    "GraphDBConfig",
+    "GraphRagConfig",
+    "InputConfig",
+    "LLMConfig",
+    "LLMParameters",
+    "LocalSearchConfig",
+    "ParallelizationParameters",
+    "QueryContextConfig",
+    "ReportingConfig",
+    "SnapshotsConfig",
+    "StorageConfig",
+    "SummarizeDescriptionsConfig",
+    "TextEmbeddingConfig",
+    "UmapConfig",
+]
diff --git a/func-app/graphrag/config/models/cache_config.py b/func-app/graphrag/config/models/cache_config.py
new file mode 100644
index 0000000000..4589edce0b
--- /dev/null
+++ b/func-app/graphrag/config/models/cache_config.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pydantic import BaseModel, Field + +import graphrag.config.defaults as defs +from graphrag.config.enums import CacheType + + +class CacheConfig(BaseModel): + """The default configuration section for Cache.""" + + type: CacheType = Field( + description="The cache type to use.", default=defs.CACHE_TYPE + ) + base_dir: str = Field( + description="The base directory for the cache.", default=defs.CACHE_BASE_DIR + ) + connection_string: str | None = Field( + description="The cache connection string to use.", default=None + ) + container_name: str | None = Field( + description="The cache container name to use.", default=None + ) + storage_account_blob_url: str | None = Field( + description="The storage account blob url to use.", default=None + ) diff --git a/func-app/graphrag/config/models/chunking_config.py b/func-app/graphrag/config/models/chunking_config.py new file mode 100644 index 0000000000..4ca8a8d38c --- /dev/null +++ b/func-app/graphrag/config/models/chunking_config.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pydantic import BaseModel, Field + +import graphrag.config.defaults as defs + + +class ChunkingConfig(BaseModel): + """Configuration section for chunking.""" + + size: int = Field(description="The chunk size to use.", default=defs.CHUNK_SIZE) + overlap: int = Field( + description="The chunk overlap to use.", default=defs.CHUNK_OVERLAP + ) + group_by_columns: list[str] = Field( + description="The chunk by columns to use.", + default=defs.CHUNK_GROUP_BY_COLUMNS, + ) + strategy: dict | None = Field( + description="The chunk strategy to use, overriding the default tokenization strategy", + default=None, + ) + encoding_model: str | None = Field( + default=None, description="The encoding model to use." + ) + + def resolved_strategy(self, encoding_model: str) -> dict: + """Get the resolved chunking strategy.""" + from graphrag.index.verbs.text.chunk import ChunkStrategyType + + return self.strategy or { + "type": ChunkStrategyType.tokens, + "chunk_size": self.size, + "chunk_overlap": self.overlap, + "group_by_columns": self.group_by_columns, + "encoding_name": self.encoding_model or encoding_model, + } diff --git a/func-app/graphrag/config/models/claim_extraction_config.py b/func-app/graphrag/config/models/claim_extraction_config.py new file mode 100644 index 0000000000..a26fdad26e --- /dev/null +++ b/func-app/graphrag/config/models/claim_extraction_config.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pathlib import Path
+
+from pydantic import Field
+
+import graphrag.config.defaults as defs
+
+from .llm_config import LLMConfig
+
+
+class ClaimExtractionConfig(LLMConfig):
+    """Configuration section for claim extraction."""
+
+    enabled: bool = Field(
+        description="Whether claim extraction is enabled.",
+    )
+    prompt: str | None = Field(
+        description="The claim extraction prompt to use.", default=None
+    )
+    description: str = Field(
+        description="The claim description to use.",
+        default=defs.CLAIM_DESCRIPTION,
+    )
+    max_gleanings: int = Field(
+        description="The maximum number of claim gleanings to use.",
+        default=defs.CLAIM_MAX_GLEANINGS,
+    )
+    strategy: dict | None = Field(
+        description="The override strategy to use.", default=None
+    )
+    encoding_model: str | None = Field(
+        default=None, description="The encoding model to use."
+    )
+
+    def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
+        """Get the resolved claim extraction strategy."""
+        from graphrag.index.verbs.covariates.extract_covariates import (
+            ExtractClaimsStrategyType,
+        )
+
+        return self.strategy or {
+            "type": ExtractClaimsStrategyType.graph_intelligence,
+            "llm": self.llm.model_dump(),
+            **self.parallelization.model_dump(),
+            "extraction_prompt": (Path(root_dir) / self.prompt)
+            .read_bytes()
+            .decode(encoding="utf-8")
+            if self.prompt
+            else None,
+            "claim_description": self.description,
+            "max_gleanings": self.max_gleanings,
+            "encoding_name": self.encoding_model or encoding_model,
+        }
diff --git a/func-app/graphrag/config/models/cluster_graph_config.py b/func-app/graphrag/config/models/cluster_graph_config.py
new file mode 100644
index 0000000000..3029baebcb
--- /dev/null
+++ b/func-app/graphrag/config/models/cluster_graph_config.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class ClusterGraphConfig(BaseModel):
+    """Configuration section for clustering graphs."""
+
+    max_cluster_size: int = Field(
+        description="The maximum cluster size to use.", default=defs.MAX_CLUSTER_SIZE
+    )
+    strategy: dict | None = Field(
+        description="The cluster strategy to use.", default=None
+    )
+
+    def resolved_strategy(self) -> dict:
+        """Get the resolved cluster strategy."""
+        from graphrag.index.verbs.graph.clustering import GraphCommunityStrategyType
+
+        return self.strategy or {
+            "type": GraphCommunityStrategyType.leiden,
+            "max_cluster_size": self.max_cluster_size,
+        }
diff --git a/func-app/graphrag/config/models/community_reports_config.py b/func-app/graphrag/config/models/community_reports_config.py
new file mode 100644
index 0000000000..ab55063cec
--- /dev/null
+++ b/func-app/graphrag/config/models/community_reports_config.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pathlib import Path
+
+from pydantic import Field
+
+import graphrag.config.defaults as defs
+
+from .llm_config import LLMConfig
+
+
+class CommunityReportsConfig(LLMConfig):
+    """Configuration section for community reports."""
+
+    prompt: str | None = Field(
+        description="The community report extraction prompt to use.", default=None
+    )
+    max_length: int = Field(
+        description="The community report maximum length in tokens.",
+        default=defs.COMMUNITY_REPORT_MAX_LENGTH,
+    )
+    max_input_length: int = Field(
+        description="The maximum input length in tokens to use when generating reports.",
+        default=defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH,
+    )
+    strategy: dict | None = Field(
+        description="The override strategy to use.", default=None
+    )
+
+    def resolved_strategy(self, root_dir) -> dict:
+        """Get the resolved community report extraction strategy."""
+        from graphrag.index.verbs.graph.report import CreateCommunityReportsStrategyType
+
+        return self.strategy or {
+            "type": CreateCommunityReportsStrategyType.graph_intelligence,
+            "llm": self.llm.model_dump(),
+            **self.parallelization.model_dump(),
+            "extraction_prompt": (Path(root_dir) / self.prompt)
+            .read_bytes()
+            .decode(encoding="utf-8")
+            if self.prompt
+            else None,
+            "max_report_length": self.max_length,
+            "max_input_length": self.max_input_length,
+        }
diff --git a/func-app/graphrag/config/models/embed_graph_config.py b/func-app/graphrag/config/models/embed_graph_config.py
new file mode 100644
index 0000000000..8b7677ab10
--- /dev/null
+++ b/func-app/graphrag/config/models/embed_graph_config.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class EmbedGraphConfig(BaseModel):
+    """The default configuration section for Node2Vec."""
+
+    enabled: bool = Field(
+        description="A flag indicating whether to enable node2vec.",
+        default=defs.NODE2VEC_ENABLED,
+    )
+    num_walks: int = Field(
+        description="The node2vec number of walks.", default=defs.NODE2VEC_NUM_WALKS
+    )
+    walk_length: int = Field(
+        description="The node2vec walk length.", default=defs.NODE2VEC_WALK_LENGTH
+    )
+    window_size: int = Field(
+        description="The node2vec window size.", default=defs.NODE2VEC_WINDOW_SIZE
+    )
+    iterations: int = Field(
+        description="The node2vec iterations.", default=defs.NODE2VEC_ITERATIONS
+    )
+    random_seed: int = Field(
+        description="The node2vec random seed.", default=defs.NODE2VEC_RANDOM_SEED
+    )
+    strategy: dict | None = Field(
+        description="The graph embedding strategy override.", default=None
+    )
+
+    def resolved_strategy(self) -> dict:
+        """Get the resolved node2vec strategy."""
+        from graphrag.index.verbs.graph.embed import EmbedGraphStrategyType
+
+        return self.strategy or {
+            "type": EmbedGraphStrategyType.node2vec,
+            "num_walks": self.num_walks,
+            "walk_length": self.walk_length,
+            "window_size": self.window_size,
+            "iterations": self.iterations,
+            # Seed the embedding from the configured random seed, not the iteration count.
+            "random_seed": self.random_seed,
+        }
diff --git a/func-app/graphrag/config/models/entity_extraction_config.py b/func-app/graphrag/config/models/entity_extraction_config.py
new file mode 100644
index 0000000000..ca160bc4e2
--- /dev/null
+++ b/func-app/graphrag/config/models/entity_extraction_config.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pathlib import Path + +from pydantic import Field + +import graphrag.config.defaults as defs + +from .llm_config import LLMConfig + + +class EntityExtractionConfig(LLMConfig): + """Configuration section for entity extraction.""" + + prompt: str | None = Field( + description="The entity extraction prompt to use.", default=None + ) + entity_types: list[str] = Field( + description="The entity extraction entity types to use.", + default=defs.ENTITY_EXTRACTION_ENTITY_TYPES, + ) + max_gleanings: int = Field( + description="The maximum number of entity gleanings to use.", + default=defs.ENTITY_EXTRACTION_MAX_GLEANINGS, + ) + strategy: dict | None = Field( + description="Override the default entity extraction strategy", default=None + ) + encoding_model: str | None = Field( + default=None, description="The encoding model to use." + ) + + def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict: + """Get the resolved entity extraction strategy.""" + from graphrag.index.verbs.entities.extraction import ExtractEntityStrategyType + + return self.strategy or { + "type": ExtractEntityStrategyType.graph_intelligence, + "llm": self.llm.model_dump(), + **self.parallelization.model_dump(), + "extraction_prompt": (Path(root_dir) / self.prompt) + .read_bytes() + .decode(encoding="utf-8") + if self.prompt + else None, + "max_gleanings": self.max_gleanings, + # It's prechunked in create_base_text_units + "encoding_name": self.encoding_model or encoding_model, + "prechunked": True, + } diff --git a/func-app/graphrag/config/models/global_search_config.py b/func-app/graphrag/config/models/global_search_config.py new file mode 100644 index 0000000000..9eb388c373 --- /dev/null +++ b/func-app/graphrag/config/models/global_search_config.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class GlobalSearchConfig(BaseModel):
+    """The default configuration section for Global Search."""
+
+    temperature: float | None = Field(
+        description="The temperature to use for token generation.",
+        default=defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
+    )
+    top_p: float | None = Field(
+        description="The top-p value to use for token generation.",
+        default=defs.GLOBAL_SEARCH_LLM_TOP_P,
+    )
+    n: int | None = Field(
+        description="The number of completions to generate.",
+        default=defs.GLOBAL_SEARCH_LLM_N,
+    )
+    max_tokens: int = Field(
+        description="The maximum context size in tokens.",
+        default=defs.GLOBAL_SEARCH_MAX_TOKENS,
+    )
+    data_max_tokens: int = Field(
+        description="The data llm maximum tokens.",
+        default=defs.GLOBAL_SEARCH_DATA_MAX_TOKENS,
+    )
+    map_max_tokens: int = Field(
+        description="The map llm maximum tokens.",
+        default=defs.GLOBAL_SEARCH_MAP_MAX_TOKENS,
+    )
+    reduce_max_tokens: int = Field(
+        description="The reduce llm maximum tokens.",
+        default=defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS,
+    )
+    concurrency: int = Field(
+        description="The number of concurrent requests.",
+        default=defs.GLOBAL_SEARCH_CONCURRENCY,
+    )
diff --git a/func-app/graphrag/config/models/graph_rag_config.py b/func-app/graphrag/config/models/graph_rag_config.py
new file mode 100644
index 0000000000..e7249a9016
--- /dev/null
+++ b/func-app/graphrag/config/models/graph_rag_config.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from devtools import pformat
+from pydantic import Field
+
+import graphrag.config.defaults as defs
+
+from .cache_config import CacheConfig
+from .chunking_config import ChunkingConfig
+from .claim_extraction_config import ClaimExtractionConfig
+from .cluster_graph_config import ClusterGraphConfig
+from .community_reports_config import CommunityReportsConfig
+from .embed_graph_config import EmbedGraphConfig
+from .entity_extraction_config import EntityExtractionConfig
+from .global_search_config import GlobalSearchConfig
+from .graphdb_config import GraphDBConfig
+from .input_config import InputConfig
+from .llm_config import LLMConfig
+from .local_search_config import LocalSearchConfig
+from .query_context_config import QueryContextConfig
+from .reporting_config import ReportingConfig
+from .snapshots_config import SnapshotsConfig
+from .storage_config import StorageConfig
+from .summarize_descriptions_config import (
+    SummarizeDescriptionsConfig,
+)
+from .text_embedding_config import TextEmbeddingConfig
+from .umap_config import UmapConfig
+
+
+class GraphRagConfig(LLMConfig):
+    """Base class for the Default-Configuration parameterization settings."""
+
+    def __repr__(self) -> str:
+        """Get a string representation."""
+        return pformat(self, highlight=False)
+
+    def __str__(self):
+        """Get a string representation."""
+        return self.model_dump_json(indent=4)
+
+    root_dir: str = Field(
+        description="The root directory for the configuration.", default=None
+    )
+
+    reporting: ReportingConfig = Field(
+        description="The reporting configuration.", default=ReportingConfig()
+    )
+    """The reporting configuration."""
+
+    storage: StorageConfig = Field(
+        description="The storage configuration.", default=StorageConfig()
+    )
+    """The storage configuration."""
+
+    cache: CacheConfig = Field(
+        description="The cache configuration.", default=CacheConfig()
+    )
+    """The cache configuration."""
+
+    input: InputConfig = Field(
+        description="The input configuration.", default=InputConfig()
+    )
+    """The input configuration."""
+
+    embed_graph: EmbedGraphConfig = Field(
+        description="Graph embedding configuration.",
+        default=EmbedGraphConfig(),
+    )
+    """Graph Embedding configuration."""
+
+    embeddings: TextEmbeddingConfig = Field(
+        description="The embeddings LLM configuration to use.",
+        default=TextEmbeddingConfig(),
+    )
+    """The embeddings LLM configuration to use."""
+
+    chunks: ChunkingConfig = Field(
+        description="The chunking configuration to use.",
+        default=ChunkingConfig(),
+    )
+    """The chunking configuration to use."""
+
+    snapshots: SnapshotsConfig = Field(
+        description="The snapshots configuration to use.",
+        default=SnapshotsConfig(),
+    )
+    """The snapshots configuration to use."""
+
+    entity_extraction: EntityExtractionConfig = Field(
+        description="The entity extraction configuration to use.",
+        default=EntityExtractionConfig(),
+    )
+    """The entity extraction configuration to use."""
+
+    summarize_descriptions: SummarizeDescriptionsConfig = Field(
+        description="The description summarization configuration to use.",
+        default=SummarizeDescriptionsConfig(),
+    )
+    """The description summarization configuration to use."""
+
+    community_reports: CommunityReportsConfig = Field(
+        description="The community reports configuration to use.",
+        default=CommunityReportsConfig(),
+    )
+    """The community reports configuration to use."""
+
+    claim_extraction: ClaimExtractionConfig = Field(
+        description="The claim extraction configuration to use.",
+        default=ClaimExtractionConfig(
+            enabled=defs.CLAIM_EXTRACTION_ENABLED,
+        ),
+    )
+    """The claim extraction configuration to use."""
+
+    cluster_graph: ClusterGraphConfig = Field(
+        description="The cluster graph configuration to use.",
+        default=ClusterGraphConfig(),
+    )
+    """The cluster graph configuration to use."""
+
+    umap: UmapConfig = Field(
+        description="The UMAP configuration to use.", default=UmapConfig()
+    )
+    """The UMAP configuration to use."""
+
+    local_search: LocalSearchConfig = Field(
+        description="The local search configuration.", default=LocalSearchConfig()
+    )
+    """The local search configuration."""
+
+    global_search: GlobalSearchConfig = Field(
+        description="The global search configuration.", default=GlobalSearchConfig()
+    )
+    """The global search configuration."""
+
+    encoding_model: str = Field(
+        description="The encoding model to use.", default=defs.ENCODING_MODEL
+    )
+    """The encoding model to use."""
+
+    skip_workflows: list[str] = Field(
+        description="The workflows to skip, usually for testing reasons.", default=[]
+    )
+    """The workflows to skip, usually for testing reasons."""
+
+    query_context: QueryContextConfig = Field(
+        description="The query context to use.", default=QueryContextConfig()
+    )
+    """The query context to use."""
+
+    graphdb: GraphDBConfig = Field(
+        description="The parameters to use for graphdb.", default=GraphDBConfig()
+    )
+    """The parameters to use for graphdb."""
\ No newline at end of file
diff --git a/func-app/graphrag/config/models/graphdb_config.py b/func-app/graphrag/config/models/graphdb_config.py
new file mode 100644
index 0000000000..8ee0f9d276
--- /dev/null
+++ b/func-app/graphrag/config/models/graphdb_config.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+
+class GraphDBConfig(BaseModel):
+    """The default configuration section for Graph DB."""
+
+    account_name: str | None = Field(
+        description="Graphdb account name", default=None
+    )
+    account_key: str | None = Field(
+        description="Graphdb account key", default=None
+    )
+    username: str | None = Field(
+        description="Graphdb username", default=None
+    )
+    enabled: bool = Field(
+        description="Flag to enable querying into graphdb", default=False
+    )
+    cosmos_url: str | None = Field(
+        description="Cosmos account url", default=None
+    )
+    gremlin_url: str | None = Field(
+        description="Gremlin db url", default=None
+    )
\ No newline at end of file
diff --git a/func-app/graphrag/config/models/input_config.py b/func-app/graphrag/config/models/input_config.py
new file mode 100644
index 0000000000..f9e5847af6
--- /dev/null
+++ b/func-app/graphrag/config/models/input_config.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+from graphrag.config.enums import InputFileType, InputType
+
+
+class InputConfig(BaseModel):
+    """The default configuration section for Input."""
+
+    type: InputType = Field(
+        description="The input type to use.", default=defs.INPUT_TYPE
+    )
+    file_type: InputFileType = Field(
+        description="The input file type to use.", default=defs.INPUT_FILE_TYPE
+    )
+    base_dir: str = Field(
+        description="The input base directory to use.", default=defs.INPUT_BASE_DIR
+    )
+    connection_string: str | None = Field(
+        description="The azure blob storage connection string to use.", default=None
+    )
+    storage_account_blob_url: str | None = Field(
+        description="The storage account blob url to use.", default=None
+    )
+    container_name: str | None = Field(
+        description="The azure blob storage container name to use.", default=None
+    )
+    encoding: str | None = Field(
+        description="The input file encoding to use.",
+        default=defs.INPUT_FILE_ENCODING,
+    )
+    file_pattern: str = Field(
+        description="The input file pattern to use.", default=defs.INPUT_TEXT_PATTERN
+    )
+    file_filter: dict[str, str] | None = Field(
+        description="The optional file filter for the input files.", default=None
+    )
+    source_column: str | None = Field(
+        description="The input source column to use.", default=None
+    )
+    timestamp_column: str | None = Field(
+        description="The input timestamp column to use.", default=None
+    )
+    timestamp_format: str | None = Field(
+        description="The input timestamp format to use.", default=None
+    )
+    text_column: str = Field(
+        description="The input text column to use.", default=defs.INPUT_TEXT_COLUMN
+    )
+    title_column: str | None = Field(
+        description="The input title column to use.", default=None
+    )
+    document_attribute_columns: list[str] = Field(
+        description="The document attribute columns to use.", default=[]
+    )
diff --git a/func-app/graphrag/config/models/llm_config.py b/func-app/graphrag/config/models/llm_config.py
new file mode 100644
index 0000000000..62c193b0c5
--- /dev/null
+++ b/func-app/graphrag/config/models/llm_config.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from datashaper import AsyncType +from pydantic import BaseModel, Field + +import graphrag.config.defaults as defs + +from .llm_parameters import LLMParameters +from .parallelization_parameters import ParallelizationParameters + + +class LLMConfig(BaseModel): + """Base class for LLM-configured steps.""" + + llm: LLMParameters = Field( + description="The LLM configuration to use.", default=LLMParameters() + ) + parallelization: ParallelizationParameters = Field( + description="The parallelization configuration to use.", + default=ParallelizationParameters(), + ) + async_mode: AsyncType = Field( + description="The async mode to use.", default=defs.ASYNC_MODE + ) diff --git a/func-app/graphrag/config/models/llm_parameters.py b/func-app/graphrag/config/models/llm_parameters.py new file mode 100644 index 0000000000..df81138a2f --- /dev/null +++ b/func-app/graphrag/config/models/llm_parameters.py @@ -0,0 +1,87 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Parameters model.""" + +from pydantic import BaseModel, ConfigDict, Field + +import graphrag.config.defaults as defs +from graphrag.config.enums import LLMType + + +class LLMParameters(BaseModel): + """LLM Parameters model.""" + + model_config = ConfigDict(protected_namespaces=(), extra="allow") + api_key: str | None = Field( + description="The API key to use for the LLM service.", + default=None, + ) + type: LLMType = Field( + description="The type of LLM model to use.", default=defs.LLM_TYPE + ) + model: str = Field(description="The LLM model to use.", default=defs.LLM_MODEL) + max_tokens: int | None = Field( + description="The maximum number of tokens to generate.", + default=defs.LLM_MAX_TOKENS, + ) + temperature: float | None = Field( + description="The temperature to use for token generation.", + default=defs.LLM_TEMPERATURE, + ) + top_p: float | None = Field( + description="The top-p value to use for token generation.", + default=defs.LLM_TOP_P, + ) + n: int | None = Field( + description="The number of completions to generate.", + default=defs.LLM_N, + ) + request_timeout: float = Field( + description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT + ) + api_base: str | None = Field( + description="The base URL for the LLM API.", default=None + ) + api_version: str | None = Field( + description="The version of the LLM API to use.", default=None + ) + organization: str | None = Field( + description="The organization to use for the LLM service.", default=None + ) + proxy: str | None = Field( + description="The proxy to use for the LLM service.", default=None + ) + cognitive_services_endpoint: str | None = Field( + description="The endpoint to reach cognitives services.", default=None + ) + deployment_name: str | None = Field( + description="The deployment name to use for the LLM service.", default=None + ) + model_supports_json: bool | None = Field( + description="Whether the model supports JSON output mode.", default=None + ) + tokens_per_minute: int = Field( + description="The number of tokens per minute to use for the LLM service.", + default=defs.LLM_TOKENS_PER_MINUTE, + ) + requests_per_minute: int = Field( + description="The number of requests per minute to use for the LLM service.", + default=defs.LLM_REQUESTS_PER_MINUTE, + ) + max_retries: int = Field( + description="The maximum number of retries to use for the LLM service.", + default=defs.LLM_MAX_RETRIES, + ) + 
max_retry_wait: float = Field(
+        description="The maximum retry wait to use for the LLM service.",
+        default=defs.LLM_MAX_RETRY_WAIT,
+    )
+    sleep_on_rate_limit_recommendation: bool = Field(
+        description="Whether to sleep on rate limit recommendations.",
+        default=defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION,
+    )
+    concurrent_requests: int = Field(
+        description="The number of concurrent requests to use for the LLM service.",
+        default=defs.LLM_CONCURRENT_REQUESTS,
+    )
diff --git a/func-app/graphrag/config/models/local_search_config.py b/func-app/graphrag/config/models/local_search_config.py
new file mode 100644
index 0000000000..c41344daef
--- /dev/null
+++ b/func-app/graphrag/config/models/local_search_config.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class LocalSearchConfig(BaseModel):
+    """The default configuration section for Local Search."""
+
+    text_unit_prop: float = Field(
+        description="The text unit proportion.",
+        default=defs.LOCAL_SEARCH_TEXT_UNIT_PROP,
+    )
+    community_prop: float = Field(
+        description="The community proportion.",
+        default=defs.LOCAL_SEARCH_COMMUNITY_PROP,
+    )
+    conversation_history_max_turns: int = Field(
+        description="The conversation history maximum turns.",
+        default=defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
+    )
+    top_k_entities: int = Field(
+        description="The top k mapped entities.",
+        default=defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
+    )
+    top_k_relationships: int = Field(
+        description="The top k mapped relations.",
+        default=defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
+    )
+    temperature: float | None = Field(
+        description="The temperature to use for token generation.",
+        default=defs.LOCAL_SEARCH_LLM_TEMPERATURE,
+    )
+    top_p: float | None = Field(
+        description="The top-p value to use for token generation.",
+        default=defs.LOCAL_SEARCH_LLM_TOP_P,
+    )
+    n: int | None = Field(
+        description="The number of completions to generate.",
+        default=defs.LOCAL_SEARCH_LLM_N,
+    )
+    max_tokens: int = Field(
+        description="The maximum tokens.", default=defs.LOCAL_SEARCH_MAX_TOKENS
+    )
+    llm_max_tokens: int = Field(
+        description="The LLM maximum tokens.", default=defs.LOCAL_SEARCH_LLM_MAX_TOKENS
+    )
diff --git a/func-app/graphrag/config/models/parallelization_parameters.py b/func-app/graphrag/config/models/parallelization_parameters.py
new file mode 100644
index 0000000000..80a85b8639
--- /dev/null
+++ b/func-app/graphrag/config/models/parallelization_parameters.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parallelization parameters model."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class ParallelizationParameters(BaseModel):
+    """Parallelization parameters model."""
+
+    stagger: float = Field(
+        description="The stagger to use for the LLM service.",
+        default=defs.PARALLELIZATION_STAGGER,
+    )
+    num_threads: int = Field(
+        description="The number of threads to use for the LLM service.",
+        default=defs.PARALLELIZATION_NUM_THREADS,
+    )
diff --git a/func-app/graphrag/config/models/query_context_config.py b/func-app/graphrag/config/models/query_context_config.py
new file mode 100644
index 0000000000..15626efba9
--- /dev/null
+++ b/func-app/graphrag/config/models/query_context_config.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+
+class QueryContextConfig(BaseModel):
+    """The default configuration section for Query Context."""
+
+    files: list[str] = Field(
+        description="The list of files on which the query should be run.",
+        default=[],
+    )
\ No newline at end of file
diff --git a/func-app/graphrag/config/models/reporting_config.py b/func-app/graphrag/config/models/reporting_config.py
new file mode 100644
index 0000000000..35e86cf5da
--- /dev/null
+++ b/func-app/graphrag/config/models/reporting_config.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+from graphrag.config.enums import ReportingType
+
+
+class ReportingConfig(BaseModel):
+    """The default configuration section for Reporting."""
+
+    type: ReportingType = Field(
+        description="The reporting type to use.", default=defs.REPORTING_TYPE
+    )
+    base_dir: str = Field(
+        description="The base directory for reporting.",
+        default=defs.REPORTING_BASE_DIR,
+    )
+    connection_string: str | None = Field(
+        description="The reporting connection string to use.", default=None
+    )
+    container_name: str | None = Field(
+        description="The reporting container name to use.", default=None
+    )
+    storage_account_blob_url: str | None = Field(
+        description="The storage account blob url to use.", default=None
+    )
diff --git a/func-app/graphrag/config/models/snapshots_config.py b/func-app/graphrag/config/models/snapshots_config.py
new file mode 100644
index 0000000000..08293fb7a7
--- /dev/null
+++ b/func-app/graphrag/config/models/snapshots_config.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class SnapshotsConfig(BaseModel):
+    """Configuration section for snapshots."""
+
+    graphml: bool = Field(
+        description="A flag indicating whether to take snapshots of GraphML.",
+        default=defs.SNAPSHOTS_GRAPHML,
+    )
+    raw_entities: bool = Field(
+        description="A flag indicating whether to take snapshots of raw entities.",
+        default=defs.SNAPSHOTS_RAW_ENTITIES,
+    )
+    top_level_nodes: bool = Field(
+        description="A flag indicating whether to take snapshots of top-level nodes.",
+        default=defs.SNAPSHOTS_TOP_LEVEL_NODES,
+    )
diff --git a/func-app/graphrag/config/models/storage_config.py b/func-app/graphrag/config/models/storage_config.py
new file mode 100644
index 0000000000..b3b5c70fe0
--- /dev/null
+++ b/func-app/graphrag/config/models/storage_config.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+from graphrag.config.enums import StorageType
+
+
+class StorageConfig(BaseModel):
+    """The default configuration section for Storage."""
+
+    type: StorageType = Field(
+        description="The storage type to use.", default=defs.STORAGE_TYPE
+    )
+    base_dir: str = Field(
+        description="The base directory for the storage.",
+        default=defs.STORAGE_BASE_DIR,
+    )
+    connection_string: str | None = Field(
+        description="The storage connection string to use.", default=None
+    )
+    container_name: str | None = Field(
+        description="The storage container name to use.", default=None
+    )
+    storage_account_blob_url: str | None = Field(
+        description="The storage account blob url to use.", default=None
+    )
+    overwrite: bool = Field(
+        description="If true, overwrite existing containers instead of raising an error; otherwise raise an error.",
+        default=False,
+    )
diff --git a/func-app/graphrag/config/models/summarize_descriptions_config.py b/func-app/graphrag/config/models/summarize_descriptions_config.py
new file mode 100644
index 0000000000..9747d949c6
--- /dev/null
+++ b/func-app/graphrag/config/models/summarize_descriptions_config.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pathlib import Path
+
+from pydantic import Field
+
+import graphrag.config.defaults as defs
+
+from .llm_config import LLMConfig
+
+
+class SummarizeDescriptionsConfig(LLMConfig):
+    """Configuration section for description summarization."""
+
+    prompt: str | None = Field(
+        description="The description summarization prompt to use.", default=None
+    )
+    max_length: int = Field(
+        description="The description summarization maximum length.",
+        default=defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH,
+    )
+    strategy: dict | None = Field(
+        description="The override strategy to use.", default=None
+    )
+
+    def resolved_strategy(self, root_dir: str) -> dict:
+        """Get the resolved description summarization strategy."""
+        from graphrag.index.verbs.entities.summarize import SummarizeStrategyType
+
+        return self.strategy or {
+            "type": SummarizeStrategyType.graph_intelligence,
+            "llm": self.llm.model_dump(),
+            **self.parallelization.model_dump(),
+            "summarize_prompt": (Path(root_dir) / self.prompt)
+            .read_bytes()
+            .decode(encoding="utf-8")
+            if self.prompt
+            else None,
+            "max_summary_length": self.max_length,
+        }
diff --git a/func-app/graphrag/config/models/text_embedding_config.py b/func-app/graphrag/config/models/text_embedding_config.py
new file mode 100644
index 0000000000..5c2fcdb86e
--- /dev/null
+++ b/func-app/graphrag/config/models/text_embedding_config.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pydantic import Field + +import graphrag.config.defaults as defs +from graphrag.config.enums import TextEmbeddingTarget + +from .llm_config import LLMConfig + + +class TextEmbeddingConfig(LLMConfig): + """Configuration section for text embeddings.""" + + batch_size: int = Field( + description="The batch size to use.", default=defs.EMBEDDING_BATCH_SIZE + ) + batch_max_tokens: int = Field( + description="The batch max tokens to use.", + default=defs.EMBEDDING_BATCH_MAX_TOKENS, + ) + target: TextEmbeddingTarget = Field( + description="The target to use. 'all' or 'required'.", + default=defs.EMBEDDING_TARGET, + ) + skip: list[str] = Field(description="The specific embeddings to skip.", default=[]) + vector_store: dict | None = Field( + description="The vector storage configuration", default=None + ) + strategy: dict | None = Field( + description="The override strategy to use.", default=None + ) + + def resolved_strategy(self) -> dict: + """Get the resolved text embedding strategy.""" + from graphrag.index.verbs.text.embed import TextEmbedStrategyType + + return self.strategy or { + "type": TextEmbedStrategyType.openai, + "llm": self.llm.model_dump(), + **self.parallelization.model_dump(), + "batch_size": self.batch_size, + "batch_max_tokens": self.batch_max_tokens, + } diff --git a/func-app/graphrag/config/models/umap_config.py b/func-app/graphrag/config/models/umap_config.py new file mode 100644 index 0000000000..1d9bd93ead --- /dev/null +++ b/func-app/graphrag/config/models/umap_config.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pydantic import BaseModel, Field + +import graphrag.config.defaults as defs + + +class UmapConfig(BaseModel): + """Configuration section for UMAP.""" + + enabled: bool = Field( + description="A flag indicating whether to enable UMAP.", + default=defs.UMAP_ENABLED, + ) diff --git a/func-app/graphrag/config/read_dotenv.py b/func-app/graphrag/config/read_dotenv.py new file mode 100644 index 0000000000..7e041757b3 --- /dev/null +++ b/func-app/graphrag/config/read_dotenv.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing the read_dotenv utility.""" + +import logging +import os +from pathlib import Path + +from dotenv import dotenv_values + +log = logging.getLogger(__name__) + + +def read_dotenv(root: str) -> None: + """Read a .env file in the given root path.""" + env_path = Path(root) / ".env" + if env_path.exists(): + log.info("Loading pipeline .env file") + env_config = dotenv_values(f"{env_path}") + for key, value in env_config.items(): + if key not in os.environ: + os.environ[key] = value or "" + else: + log.info("No .env file found at %s", root) diff --git a/func-app/graphrag/index/__init__.py b/func-app/graphrag/index/__init__.py new file mode 100644 index 0000000000..38ab263620 --- /dev/null +++ b/func-app/graphrag/index/__init__.py @@ -0,0 +1,78 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""The Indexing Engine package root.""" + +from .cache import PipelineCache +from .config import ( + PipelineBlobCacheConfig, + PipelineBlobReportingConfig, + PipelineBlobStorageConfig, + PipelineCacheConfig, + PipelineCacheConfigTypes, + PipelineConfig, + PipelineConsoleReportingConfig, + PipelineCSVInputConfig, + PipelineFileCacheConfig, + PipelineFileReportingConfig, + PipelineFileStorageConfig, + PipelineInputConfig, + PipelineInputConfigTypes, + PipelineMemoryCacheConfig, + PipelineMemoryStorageConfig, + PipelineNoneCacheConfig, + PipelineReportingConfig, + PipelineReportingConfigTypes, + PipelineStorageConfig, + PipelineStorageConfigTypes, + PipelineTextInputConfig, + PipelineWorkflowConfig, + PipelineWorkflowReference, + PipelineWorkflowStep, +) +from .create_pipeline_config import create_pipeline_config +from .errors import ( + NoWorkflowsDefinedError, + UndefinedWorkflowError, + UnknownWorkflowError, +) +from .load_pipeline_config import load_pipeline_config +from .run import run_pipeline, run_pipeline_with_config +from graphrag.common.storage import PipelineStorage + +__all__ = [ + "NoWorkflowsDefinedError", + "PipelineBlobCacheConfig", + "PipelineBlobCacheConfig", + "PipelineBlobReportingConfig", + "PipelineBlobStorageConfig", + "PipelineCSVInputConfig", + "PipelineCache", + "PipelineCacheConfig", + "PipelineCacheConfigTypes", + "PipelineConfig", + "PipelineConsoleReportingConfig", + "PipelineFileCacheConfig", + "PipelineFileReportingConfig", + "PipelineFileStorageConfig", + "PipelineInputConfig", + "PipelineInputConfigTypes", + "PipelineMemoryCacheConfig", + "PipelineMemoryStorageConfig", + "PipelineNoneCacheConfig", + "PipelineReportingConfig", + "PipelineReportingConfigTypes", + "PipelineStorage", + "PipelineStorageConfig", + "PipelineStorageConfigTypes", + "PipelineTextInputConfig", + "PipelineWorkflowConfig", + "PipelineWorkflowReference", + "PipelineWorkflowStep", + "UndefinedWorkflowError", + "UnknownWorkflowError", + "create_pipeline_config", + "load_pipeline_config", + "run_pipeline", + "run_pipeline_with_config", +] diff --git a/func-app/graphrag/index/__main__.py b/func-app/graphrag/index/__main__.py new file mode 100644 index 0000000000..de2c156a69 --- /dev/null +++ b/func-app/graphrag/index/__main__.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine package root.""" + +import argparse + +from .cli import index_cli + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + help="The configuration yaml file to use when running the pipeline", + required=False, + type=str, + ) + parser.add_argument( + "-v", + "--verbose", + help="Runs the pipeline with verbose logging", + action="store_true", + ) + parser.add_argument( + "--memprofile", + help="Runs the pipeline with memory profiling", + action="store_true", + ) + parser.add_argument( + "--root", + help="If no configuration is defined, the root directory to use for input data and output data. Default value: the current directory", + # Only required if config is not defined + required=False, + default=".", + type=str, + ) + parser.add_argument( + "--resume", + help="Resume a given data run leveraging Parquet output files.", + # Only required if config is not defined + required=False, + default=None, + type=str, + ) + parser.add_argument( + "--reporter", + help="The progress reporter to use. 
Valid values are 'rich', 'print', or 'none'",
+        type=str,
+    )
+    parser.add_argument(
+        "--emit",
+        help="The data formats to emit, comma-separated. Valid values are 'parquet' and 'csv'. default='parquet,csv'",
+        type=str,
+    )
+    parser.add_argument(
+        "--context_id",
+        required=False,
+        help="Context id to activate or deactivate.",
+        type=str
+    )
+    parser.add_argument(
+        "--context_operation",
+        help="Context operation: activate or deactivate.",
+        required=False,
+        # Only required if context_id is provided
+        type=str
+    )
+    parser.add_argument(
+        "--dryrun",
+        help="Run the pipeline without actually executing any steps and inspect the configuration.",
+        action="store_true",
+    )
+    parser.add_argument("--nocache", help="Disable LLM cache.", action="store_true")
+    parser.add_argument(
+        "--init",
+        help="Create an initial configuration in the given path.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--overlay_defaults",
+        help="Overlay default configuration values on a provided configuration file (--config).",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--community_level",
+        help="Community level in the Leiden community hierarchy from which to load community reports; a higher value means reports on smaller communities are used.",
+        type=int,
+        default=2,
+    )
+    parser.add_argument(
+        "--use_kusto_community_reports",
+        help="If enabled, community reports are loaded into Kusto during activation.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--optimized_search",
+        help="Runs optimized search and exports artifacts.",
+        # argparse's type=bool treats any non-empty string as True, so expose
+        # this as a store_true flag instead.
+        action="store_true",
+        default=False,
+    )
+
+    args = parser.parse_args()
+
+    if args.overlay_defaults and not args.config:
+        parser.error("--overlay_defaults requires --config")
+
+    index_cli(
+        root=args.root,
+        verbose=args.verbose or False,
+        resume=args.resume,
+        memprofile=args.memprofile or False,
+        nocache=args.nocache or False,
+        reporter=args.reporter,
+        config=args.config,
+        emit=args.emit,
+        dryrun=args.dryrun or False,
+        init=args.init or False,
+        overlay_defaults=args.overlay_defaults or False,
+        cli=True,
+        context_id=args.context_id,
+        context_operation=args.context_operation,
+        community_level=args.community_level,
+        use_kusto_community_reports=args.use_kusto_community_reports,
+        optimized_search=args.optimized_search
+    )
diff --git a/func-app/graphrag/index/bootstrap.py b/func-app/graphrag/index/bootstrap.py
new file mode 100644
index 0000000000..398ec88b20
--- /dev/null
+++ b/func-app/graphrag/index/bootstrap.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Bootstrap definition."""
+
+import warnings
+
+# Ignore warnings from numba
+warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
+warnings.filterwarnings("ignore", message=".*Use no seed for parallelism.*")
+
+initialized_nltk = False
+
+
+def bootstrap():
+    """Bootstrap definition."""
+    global initialized_nltk
+    if not initialized_nltk:
+        import nltk
+        from nltk.corpus import wordnet as wn
+
+        nltk.download("punkt")
+        nltk.download("averaged_perceptron_tagger")
+        nltk.download("maxent_ne_chunker")
+        nltk.download("words")
+        nltk.download("wordnet")
+        wn.ensure_loaded()
+        initialized_nltk = True
diff --git a/func-app/graphrag/index/cache/__init__.py b/func-app/graphrag/index/cache/__init__.py
new file mode 100644
index 0000000000..42ebb22994
--- /dev/null
+++ b/func-app/graphrag/index/cache/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""The Indexing Engine cache package root.""" + +from .json_pipeline_cache import JsonPipelineCache +from .load_cache import load_cache +from .memory_pipeline_cache import InMemoryCache +from .noop_pipeline_cache import NoopPipelineCache +from .pipeline_cache import PipelineCache + +__all__ = [ + "InMemoryCache", + "JsonPipelineCache", + "NoopPipelineCache", + "PipelineCache", + "load_cache", +] diff --git a/func-app/graphrag/index/cache/json_pipeline_cache.py b/func-app/graphrag/index/cache/json_pipeline_cache.py new file mode 100644 index 0000000000..30e73fedc6 --- /dev/null +++ b/func-app/graphrag/index/cache/json_pipeline_cache.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'FilePipelineCache' model.""" + +import json +from typing import Any + +from graphrag.common.storage.typing import PipelineStorage + +from .pipeline_cache import PipelineCache + + +class JsonPipelineCache(PipelineCache): + """File pipeline cache class definition.""" + + _storage: PipelineStorage + _encoding: str + + def __init__(self, storage: PipelineStorage, encoding="utf-8"): + """Init method definition.""" + self._storage = storage + self._encoding = encoding + + async def get(self, key: str) -> str | None: + """Get method definition.""" + if await self.has(key): + try: + data = await self._storage.get(key, encoding=self._encoding) + data = json.loads(data) + except UnicodeDecodeError: + await self._storage.delete(key) + return None + except json.decoder.JSONDecodeError: + await self._storage.delete(key) + return None + else: + return data.get("result") + + return None + + async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None: + """Set method definition.""" + if value is None: + return + data = {"result": value, **(debug_data or {})} + await self._storage.set(key, json.dumps(data), encoding=self._encoding) + + async def has(self, key: str) -> bool: + """Has method definition.""" + return await self._storage.has(key) + + async def delete(self, key: str) -> None: + """Delete method definition.""" + if await self.has(key): + await self._storage.delete(key) + + async def clear(self) -> None: + """Clear method definition.""" + await self._storage.clear() + + def child(self, name: str) -> "JsonPipelineCache": + """Child method definition.""" + return JsonPipelineCache(self._storage.child(name), encoding=self._encoding) diff --git a/func-app/graphrag/index/cache/load_cache.py b/func-app/graphrag/index/cache/load_cache.py new file mode 100644 index 0000000000..1a97b2e4de --- /dev/null +++ b/func-app/graphrag/index/cache/load_cache.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing load_cache method definition.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from graphrag.config.enums import CacheType +from graphrag.index.config.cache import ( + PipelineBlobCacheConfig, + PipelineFileCacheConfig, +) +from graphrag.common.storage import BlobPipelineStorage, FilePipelineStorage + +if TYPE_CHECKING: + from graphrag.index.config import ( + PipelineCacheConfig, + ) + +from .json_pipeline_cache import JsonPipelineCache +from .memory_pipeline_cache import create_memory_cache +from .noop_pipeline_cache import NoopPipelineCache + + +def load_cache(config: PipelineCacheConfig | None, root_dir: str | None): + """Load the cache from the given config.""" + if config is None: + return NoopPipelineCache() + + match config.type: + case CacheType.none: + return NoopPipelineCache() + case CacheType.memory: + return create_memory_cache() + case CacheType.file: + config = cast(PipelineFileCacheConfig, config) + storage = FilePipelineStorage(root_dir).child(config.base_dir) + return JsonPipelineCache(storage) + case CacheType.blob: + config = cast(PipelineBlobCacheConfig, config) + storage = BlobPipelineStorage( + config.connection_string, + config.container_name, + storage_account_blob_url=config.storage_account_blob_url, + ).child(config.base_dir) + return JsonPipelineCache(storage) + case _: + msg = f"Unknown cache type: {config.type}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/cache/memory_pipeline_cache.py b/func-app/graphrag/index/cache/memory_pipeline_cache.py new file mode 100644 index 0000000000..fa42f3f921 --- /dev/null +++ b/func-app/graphrag/index/cache/memory_pipeline_cache.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'InMemoryCache' model.""" + +from typing import Any + +from .pipeline_cache import PipelineCache + + +class InMemoryCache(PipelineCache): + """In memory cache class definition.""" + + _cache: dict[str, Any] + _name: str + + def __init__(self, name: str | None = None): + """Init method definition.""" + self._cache = {} + self._name = name or "" + + async def get(self, key: str) -> Any: + """Get the value for the given key. + + Args: + - key - The key to get the value for. + - as_bytes - Whether or not to return the value as bytes. + + Returns + ------- + - output - The value for the given key. + """ + key = self._create_cache_key(key) + return self._cache.get(key) + + async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None: + """Set the value for the given key. + + Args: + - key - The key to set the value for. + - value - The value to set. + """ + key = self._create_cache_key(key) + self._cache[key] = value + + async def has(self, key: str) -> bool: + """Return True if the given key exists in the storage. + + Args: + - key - The key to check for. + + Returns + ------- + - output - True if the key exists in the storage, False otherwise. + """ + key = self._create_cache_key(key) + return key in self._cache + + async def delete(self, key: str) -> None: + """Delete the given key from the storage. + + Args: + - key - The key to delete. 
+ """ + key = self._create_cache_key(key) + del self._cache[key] + + async def clear(self) -> None: + """Clear the storage.""" + self._cache.clear() + + def child(self, name: str) -> PipelineCache: + """Create a sub cache with the given name.""" + return InMemoryCache(name) + + def _create_cache_key(self, key: str) -> str: + """Create a cache key for the given key.""" + return f"{self._name}{key}" + + +def create_memory_cache() -> PipelineCache: + """Create a memory cache.""" + return InMemoryCache() diff --git a/func-app/graphrag/index/cache/noop_pipeline_cache.py b/func-app/graphrag/index/cache/noop_pipeline_cache.py new file mode 100644 index 0000000000..b7c3e60fdd --- /dev/null +++ b/func-app/graphrag/index/cache/noop_pipeline_cache.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Module containing the NoopPipelineCache implementation.""" + +from typing import Any + +from .pipeline_cache import PipelineCache + + +class NoopPipelineCache(PipelineCache): + """A no-op implementation of the pipeline cache, usually useful for testing.""" + + async def get(self, key: str) -> Any: + """Get the value for the given key. + + Args: + - key - The key to get the value for. + - as_bytes - Whether or not to return the value as bytes. + + Returns + ------- + - output - The value for the given key. + """ + return None + + async def set( + self, key: str, value: str | bytes | None, debug_data: dict | None = None + ) -> None: + """Set the value for the given key. + + Args: + - key - The key to set the value for. + - value - The value to set. + """ + + async def has(self, key: str) -> bool: + """Return True if the given key exists in the cache. + + Args: + - key - The key to check for. + + Returns + ------- + - output - True if the key exists in the cache, False otherwise. + """ + return False + + async def delete(self, key: str) -> None: + """Delete the given key from the cache. + + Args: + - key - The key to delete. + """ + + async def clear(self) -> None: + """Clear the cache.""" + + def child(self, name: str) -> PipelineCache: + """Create a child cache with the given name. + + Args: + - name - The name to create the sub cache with. + """ + return self diff --git a/func-app/graphrag/index/cache/pipeline_cache.py b/func-app/graphrag/index/cache/pipeline_cache.py new file mode 100644 index 0000000000..c68c5cfb4b --- /dev/null +++ b/func-app/graphrag/index/cache/pipeline_cache.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'PipelineCache' model.""" + +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import Any + + +class PipelineCache(metaclass=ABCMeta): + """Provide a cache interface for the pipeline.""" + + @abstractmethod + async def get(self, key: str) -> Any: + """Get the value for the given key. + + Args: + - key - The key to get the value for. + - as_bytes - Whether or not to return the value as bytes. + + Returns + ------- + - output - The value for the given key. + """ + + @abstractmethod + async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None: + """Set the value for the given key. + + Args: + - key - The key to set the value for. + - value - The value to set. + """ + + @abstractmethod + async def has(self, key: str) -> bool: + """Return True if the given key exists in the cache. + + Args: + - key - The key to check for. 
+ + Returns + ------- + - output - True if the key exists in the cache, False otherwise. + """ + + @abstractmethod + async def delete(self, key: str) -> None: + """Delete the given key from the cache. + + Args: + - key - The key to delete. + """ + + @abstractmethod + async def clear(self) -> None: + """Clear the cache.""" + + @abstractmethod + def child(self, name: str) -> PipelineCache: + """Create a child cache with the given name. + + Args: + - name - The name to create the sub cache with. + """ diff --git a/func-app/graphrag/index/cli.py b/func-app/graphrag/index/cli.py new file mode 100644 index 0000000000..2695f46af9 --- /dev/null +++ b/func-app/graphrag/index/cli.py @@ -0,0 +1,356 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Main definition.""" + +import asyncio +import json +import logging +import platform +import sys +import time +import warnings +from pathlib import Path + +from graphrag.config import ( + GraphRagConfig, + create_graphrag_config, +) +from graphrag.config.enums import ContextSwitchType +from graphrag.common.utils.common_utils import is_valid_guid +from graphrag.index import PipelineConfig, create_pipeline_config +from graphrag.index.cache import NoopPipelineCache +from graphrag.common.progress import ( + NullProgressReporter, + PrintProgressReporter, + ProgressReporter, +) +from graphrag.common.progress.rich import RichProgressReporter +from graphrag.index.run import run_pipeline_with_config + +from .emit import TableEmitterType +from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT +from .graph.extractors.community_reports.prompts import COMMUNITY_REPORT_PROMPT +from .graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT +from .graph.extractors.summarize.prompts import SUMMARIZE_PROMPT +from .init_content import INIT_DOTENV, INIT_YAML + +# Ignore warnings from numba +warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*") + +log = logging.getLogger(__name__) + +def redact(input: dict) -> str: + """Sanitize the config json.""" + + # Redact any sensitive configuration + def redact_dict(input: dict) -> dict: + if not isinstance(input, dict): + return input + + result = {} + for key, value in input.items(): + if key in { + "api_key", + "connection_string", + "container_name", + "organization", + }: + if value is not None: + result[key] = f"REDACTED, length {len(value)}" + elif isinstance(value, dict): + result[key] = redact_dict(value) + elif isinstance(value, list): + result[key] = [redact_dict(i) for i in value] + else: + result[key] = value + return result + + redacted_dict = redact_dict(input) + return json.dumps(redacted_dict, indent=4) + + +def index_cli( + root: str, + init: bool, + community_level: int, + context_operation: str | None, + context_id: str | None, + verbose: bool, + resume: str | None, + memprofile: bool, + nocache: bool, + config: str | None, + emit: str | None, + dryrun: bool, + overlay_defaults: bool, + cli: bool = False, + use_kusto_community_reports: bool = False, + optimized_search: bool = False, +): + """Run the pipeline with the given config.""" + root = Path(__file__).parent.parent.parent.__str__() + run_id = resume or time.strftime("%Y%m%d-%H%M%S") + _enable_logging(root, run_id, verbose) + progress_reporter = _get_progress_reporter("none") + _initialize_project_at(root, progress_reporter) + if overlay_defaults: + pipeline_config: str | PipelineConfig = _create_default_config( + root, config, verbose, dryrun or False, progress_reporter + ) + else: + 
pipeline_config: str | PipelineConfig = config or _create_default_config( + root, None, verbose, dryrun or False, progress_reporter + ) + + cache = NoopPipelineCache() if nocache else None + pipeline_emit = emit.split(",") if emit else None + encountered_errors = False + logging.info("Loaded the pipeline successfully") + def _run_workflow_async() -> None: + import signal + logging.info("Step1") + def handle_signal(signum, _): + # Handle the signal here + progress_reporter.info(f"Received signal {signum}, exiting...") + progress_reporter.dispose() + for task in asyncio.all_tasks(): + task.cancel() + progress_reporter.info("All tasks cancelled. Exiting...") + + # Register signal handlers for SIGINT and SIGHUP + logging.info("Step2") + #signal.signal(signal.SIGINT, handle_signal) + + logging.info("Step3") + if sys.platform != "win32": + signal.signal(signal.SIGHUP, handle_signal) + + logging.info("Step4") + async def execute(): + nonlocal encountered_errors + async for output in run_pipeline_with_config( + pipeline_config, + run_id=run_id, + memory_profile=memprofile, + cache=cache, + progress_reporter=progress_reporter, + emit=( + [TableEmitterType(e) for e in pipeline_emit] + if pipeline_emit + else None + ), + is_resume_run=bool(resume), + context_id=context_id, + ): + if output.errors and len(output.errors) > 0: + encountered_errors = True + progress_reporter.error(output.workflow) + else: + progress_reporter.success(output.workflow) + + progress_reporter.info(str(output.result)) + + if platform.system() == "Windows": + logging.info("All set to execute the workflows on Windows") + import nest_asyncio # type: ignore Ignoring because out of windows this will cause an error + + nest_asyncio.apply() + loop = asyncio.get_event_loop() + loop.run_until_complete(execute()) + elif sys.version_info >= (3, 11): + logging.info("Step6") + import uvloop # type: ignore Ignoring because on windows this will cause an error + + with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner: # type: ignore Ignoring because minor versions this will throw an error + runner.run(execute()) + else: + logging.info("Step 6") + import uvloop # type: ignore Ignoring because on windows this will cause an error + + uvloop.install() + asyncio.run(execute()) + + _run_workflow_async() + progress_reporter.stop() + if encountered_errors: + progress_reporter.error( + "Errors occurred during the pipeline run, see logs for more details." 
+        )
+    else:
+        progress_reporter.success("All workflows completed successfully.")
+
+    if cli:
+        sys.exit(1 if encountered_errors else 0)
+
+
+def _switch_context(
+    root: str,
+    config: str,
+    reporter: ProgressReporter,
+    context_operation: str | None,
+    context_id: str,
+    community_level: int,
+    optimized_search: bool,
+    use_kusto_community_reports: bool,
+) -> None:
+    """Switch the context to the given context."""
+    reporter.info(f"Switching context to {context_id} using operation {context_operation}")
+    log.info("Switching context to %s", context_id)
+    from graphrag.index.context_switch.contextSwitcher import ContextSwitcher
+
+    context_switcher = ContextSwitcher(
+        root_dir=root,
+        config_dir=config,
+        reporter=reporter,
+        context_id=context_id,
+        community_level=community_level,
+        data_dir=None,
+        optimized_search=optimized_search,
+        use_kusto_community_reports=use_kusto_community_reports)
+    if context_operation == ContextSwitchType.Activate:
+        context_switcher.activate()
+    elif context_operation == ContextSwitchType.Deactivate:
+        context_switcher.deactivate()
+    else:
+        msg = f"Invalid context operation {context_operation}"
+        raise ValueError(msg)
+
+
+def _initialize_project_at(path: str, reporter: ProgressReporter) -> None:
+    """Initialize the project at the given path."""
+    reporter.info(f"Initializing project at {path}")
+    root = Path(path)
+    if not root.exists():
+        root.mkdir(parents=True, exist_ok=True)
+
+    settings_yaml = root / "settings/settings.yaml"
+    settings_yaml.parent.mkdir(parents=True, exist_ok=True)
+    dotenv = root / ".env"
+    if not dotenv.exists():
+        with settings_yaml.open("wb") as file:
+            file.write(INIT_YAML.encode(encoding="utf-8", errors="strict"))
+
+        with dotenv.open("wb") as file:
+            file.write(INIT_DOTENV.encode(encoding="utf-8", errors="strict"))
+
+    prompts_dir = root / "prompts"
+    if not prompts_dir.exists():
+        prompts_dir.mkdir(parents=True, exist_ok=True)
+
+    entity_extraction = prompts_dir / "entity_extraction.txt"
+    if not entity_extraction.exists():
+        with entity_extraction.open("wb") as file:
+            file.write(
+                GRAPH_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
+            )
+
+    summarize_descriptions = prompts_dir / "summarize_descriptions.txt"
+    if not summarize_descriptions.exists():
+        with summarize_descriptions.open("wb") as file:
+            file.write(SUMMARIZE_PROMPT.encode(encoding="utf-8", errors="strict"))
+
+    claim_extraction = prompts_dir / "claim_extraction.txt"
+    if not claim_extraction.exists():
+        with claim_extraction.open("wb") as file:
+            file.write(
+                CLAIM_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
+            )
+
+    community_report = prompts_dir / "community_report.txt"
+    if not community_report.exists():
+        with community_report.open("wb") as file:
+            file.write(
+                COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict")
+            )
+
+
+def _create_default_config(
+    root: str,
+    config: str | None,
+    verbose: bool,
+    dryrun: bool,
+    reporter: ProgressReporter,
+) -> PipelineConfig:
+    """Overlay default values on an existing config or create a default config if none is provided."""
+    if config and not Path(config).exists():
+        msg = f"Configuration file {config} does not exist"
+        raise ValueError(msg)
+
+    if not Path(root).exists():
+        msg = f"Root directory {root} does not exist"
+        raise ValueError(msg)
+
+    parameters = _read_config_parameters(root, config, reporter)
+    log.info(
+        "using default configuration: %s",
+        redact(parameters.model_dump()),
+    )
+
+    if verbose or dryrun:
+        reporter.info(f"Using default configuration: {redact(parameters.model_dump())}")
+    result = create_pipeline_config(parameters,
verbose) + if verbose or dryrun: + reporter.info(f"Final Config: {redact(result.model_dump())}") + + if dryrun: + reporter.info("dry run complete, exiting...") + sys.exit(0) + return result + + +def _read_config_parameters(root: str, config: str | None, reporter: ProgressReporter): + _root = Path(root) + settings_yaml = ( + Path(config) + if config and Path(config).suffix in [".yaml", ".yml"] + else _root / "settings.yaml" + ) + if not settings_yaml.exists(): + settings_yaml = _root / "settings.yml" + settings_json = ( + Path(config) + if config and Path(config).suffix == ".json" + else _root / "settings.json" + ) + + if settings_yaml.exists(): + reporter.success(f"Reading settings from {settings_yaml}") + with settings_yaml.open("rb") as file: + import yaml + + data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + if settings_json.exists(): + reporter.success(f"Reading settings from {settings_json}") + with settings_json.open("rb") as file: + import json + + data = json.loads(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + reporter.success("Reading settings from environment variables") + return create_graphrag_config(root_dir=root) + + +def _get_progress_reporter(reporter_type: str | None) -> ProgressReporter: + if reporter_type is None or reporter_type == "rich": + return RichProgressReporter("GraphRAG Indexer ") + if reporter_type == "print": + return PrintProgressReporter("GraphRAG Indexer ") + if reporter_type == "none": + return NullProgressReporter() + + msg = f"Invalid progress reporter type: {reporter_type}" + raise ValueError(msg) + + +def _enable_logging(root_dir: str, run_id: str, verbose: bool) -> None: + logging_file = ( + Path(root_dir) / "output" / run_id / "reports" / "indexing-engine.log" + ) + logging_file.parent.mkdir(parents=True, exist_ok=True) + + logging_file.touch(exist_ok=True) + handler = logging.StreamHandler(stream=sys.stdout) + fileHandler = logging.FileHandler(logging_file, mode="a") + logging.basicConfig( + #filename=str(logging_file), + #filemode="a", + format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", + datefmt="%H:%M:%S", + level=logging.DEBUG if verbose else logging.INFO, + handlers=[handler, fileHandler] + ) diff --git a/func-app/graphrag/index/config/__init__.py b/func-app/graphrag/index/config/__init__.py new file mode 100644 index 0000000000..ad30859b81 --- /dev/null +++ b/func-app/graphrag/index/config/__init__.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""The Indexing Engine config typing package root.""" + +from .cache import ( + PipelineBlobCacheConfig, + PipelineCacheConfig, + PipelineCacheConfigTypes, + PipelineFileCacheConfig, + PipelineMemoryCacheConfig, + PipelineNoneCacheConfig, +) +from .input import ( + PipelineCSVInputConfig, + PipelineInputConfig, + PipelineInputConfigTypes, + PipelineTextInputConfig, +) +from .pipeline import PipelineConfig +from .reporting import ( + PipelineBlobReportingConfig, + PipelineConsoleReportingConfig, + PipelineFileReportingConfig, + PipelineReportingConfig, + PipelineReportingConfigTypes, +) +from ...common.config.storage import ( + PipelineBlobStorageConfig, + PipelineFileStorageConfig, + PipelineMemoryStorageConfig, + PipelineStorageConfig, + PipelineStorageConfigTypes, +) +from .workflow import ( + PipelineWorkflowConfig, + PipelineWorkflowReference, + PipelineWorkflowStep, +) + +__all__ = [ + "PipelineBlobCacheConfig", + "PipelineBlobReportingConfig", + "PipelineBlobStorageConfig", + "PipelineCSVInputConfig", + "PipelineCacheConfig", + "PipelineCacheConfigTypes", + "PipelineCacheConfigTypes", + "PipelineCacheConfigTypes", + "PipelineConfig", + "PipelineConsoleReportingConfig", + "PipelineFileCacheConfig", + "PipelineFileReportingConfig", + "PipelineFileStorageConfig", + "PipelineInputConfig", + "PipelineInputConfigTypes", + "PipelineMemoryCacheConfig", + "PipelineMemoryCacheConfig", + "PipelineMemoryStorageConfig", + "PipelineNoneCacheConfig", + "PipelineReportingConfig", + "PipelineReportingConfigTypes", + "PipelineStorageConfig", + "PipelineStorageConfigTypes", + "PipelineTextInputConfig", + "PipelineWorkflowConfig", + "PipelineWorkflowReference", + "PipelineWorkflowStep", +] diff --git a/func-app/graphrag/index/config/cache.py b/func-app/graphrag/index/config/cache.py new file mode 100644 index 0000000000..be1053de2e --- /dev/null +++ b/func-app/graphrag/index/config/cache.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig' and 'PipelineMemoryCacheConfig' models.""" + +from __future__ import annotations + +from typing import Generic, Literal, TypeVar + +from pydantic import BaseModel +from pydantic import Field as pydantic_Field + +from graphrag.config.enums import CacheType + +T = TypeVar("T") + + +class PipelineCacheConfig(BaseModel, Generic[T]): + """Represent the cache configuration for the pipeline.""" + + type: T + + +class PipelineFileCacheConfig(PipelineCacheConfig[Literal[CacheType.file]]): + """Represent the file cache configuration for the pipeline.""" + + type: Literal[CacheType.file] = CacheType.file + """The type of cache.""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the cache.", default=None + ) + """The base directory for the cache.""" + + +class PipelineMemoryCacheConfig(PipelineCacheConfig[Literal[CacheType.memory]]): + """Represent the memory cache configuration for the pipeline.""" + + type: Literal[CacheType.memory] = CacheType.memory + """The type of cache.""" + + +class PipelineNoneCacheConfig(PipelineCacheConfig[Literal[CacheType.none]]): + """Represent the none cache configuration for the pipeline.""" + + type: Literal[CacheType.none] = CacheType.none + """The type of cache.""" + + +class PipelineBlobCacheConfig(PipelineCacheConfig[Literal[CacheType.blob]]): + """Represents the blob cache configuration for the pipeline.""" + + type: Literal[CacheType.blob] = CacheType.blob + """The type of cache.""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the cache.", default=None + ) + """The base directory for the cache.""" + + connection_string: str | None = pydantic_Field( + description="The blob cache connection string for the cache.", default=None + ) + """The blob cache connection string for the cache.""" + + container_name: str = pydantic_Field( + description="The container name for cache", default=None + ) + """The container name for cache""" + + storage_account_blob_url: str | None = pydantic_Field( + description="The storage account blob url for cache", default=None + ) + """The storage account blob url for cache""" + + +PipelineCacheConfigTypes = ( + PipelineFileCacheConfig + | PipelineMemoryCacheConfig + | PipelineBlobCacheConfig + | PipelineNoneCacheConfig +) diff --git a/func-app/graphrag/index/config/input.py b/func-app/graphrag/index/config/input.py new file mode 100644 index 0000000000..35db357599 --- /dev/null +++ b/func-app/graphrag/index/config/input.py @@ -0,0 +1,120 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing 'PipelineInputConfig', 'PipelineCSVInputConfig' and 'PipelineTextInputConfig' models.""" + +from __future__ import annotations + +from typing import Generic, Literal, TypeVar + +from pydantic import BaseModel +from pydantic import Field as pydantic_Field + +from graphrag.config.enums import InputFileType, InputType + +from .workflow import PipelineWorkflowStep + +T = TypeVar("T") + + +class PipelineInputConfig(BaseModel, Generic[T]): + """Represent the configuration for an input.""" + + file_type: T + """The file type of input.""" + + type: InputType | None = pydantic_Field( + description="The input type to use.", + default=None, + ) + """The input type to use.""" + + connection_string: str | None = pydantic_Field( + description="The blob cache connection string for the input files.", + default=None, + ) + """The blob cache connection string for the input files.""" + + storage_account_blob_url: str | None = pydantic_Field( + description="The storage account blob url for the input files.", default=None + ) + """The storage account blob url for the input files.""" + + container_name: str | None = pydantic_Field( + description="The container name for input files.", default=None + ) + """The container name for the input files.""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the input files.", default=None + ) + """The base directory for the input files.""" + + file_pattern: str = pydantic_Field( + description="The regex file pattern for the input files." + ) + """The regex file pattern for the input files.""" + + file_filter: dict[str, str] | None = pydantic_Field( + description="The optional file filter for the input files.", default=None + ) + """The optional file filter for the input files.""" + + post_process: list[PipelineWorkflowStep] | None = pydantic_Field( + description="The post processing steps for the input.", default=None + ) + """The post processing steps for the input.""" + + encoding: str | None = pydantic_Field( + description="The encoding for the input files.", default=None + ) + """The encoding for the input files.""" + + +class PipelineCSVInputConfig(PipelineInputConfig[Literal[InputFileType.csv]]): + """Represent the configuration for a CSV input.""" + + file_type: Literal[InputFileType.csv] = InputFileType.csv + + source_column: str | None = pydantic_Field( + description="The column to use as the source of the document.", default=None + ) + """The column to use as the source of the document.""" + + timestamp_column: str | None = pydantic_Field( + description="The column to use as the timestamp of the document.", default=None + ) + """The column to use as the timestamp of the document.""" + + timestamp_format: str | None = pydantic_Field( + description="The format of the timestamp column, so it can be parsed correctly.", + default=None, + ) + """The format of the timestamp column, so it can be parsed correctly.""" + + text_column: str | None = pydantic_Field( + description="The column to use as the text of the document.", default=None + ) + """The column to use as the text of the document.""" + + title_column: str | None = pydantic_Field( + description="The column to use as the title of the document.", default=None + ) + """The column to use as the title of the document.""" + + +class PipelineTextInputConfig(PipelineInputConfig[Literal[InputFileType.text]]): + """Represent the configuration for a text input.""" + + file_type: Literal[InputFileType.text] = InputFileType.text + + 
# Text Specific + title_text_length: int | None = pydantic_Field( + description="Number of characters to use from the text as the title.", + default=None, + ) + """Number of characters to use from the text as the title.""" + + +PipelineInputConfigTypes = PipelineCSVInputConfig | PipelineTextInputConfig +"""Represent the types of inputs that can be used in a pipeline.""" diff --git a/func-app/graphrag/index/config/pipeline.py b/func-app/graphrag/index/config/pipeline.py new file mode 100644 index 0000000000..e8bbbdbf4c --- /dev/null +++ b/func-app/graphrag/index/config/pipeline.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'PipelineConfig' model.""" + +from __future__ import annotations + +from devtools import pformat +from graphrag.config.models.graphdb_config import GraphDBConfig +from pydantic import BaseModel +from pydantic import Field as pydantic_Field + +from .cache import PipelineCacheConfigTypes +from .input import PipelineInputConfigTypes +from .reporting import PipelineReportingConfigTypes +from ...common.config.storage import PipelineStorageConfigTypes +from .workflow import PipelineWorkflowReference + + +class PipelineConfig(BaseModel): + """Represent the configuration for a pipeline.""" + + def __repr__(self) -> str: + """Get a string representation.""" + return pformat(self, highlight=False) + + def __str__(self): + """Get a string representation.""" + return str(self.model_dump_json(indent=4)) + + extends: list[str] | str | None = pydantic_Field( + description="Extends another pipeline configuration", default=None + ) + """Extends another pipeline configuration""" + + input: PipelineInputConfigTypes | None = pydantic_Field( + default=None, discriminator="file_type" + ) + """The input configuration for the pipeline.""" + + reporting: PipelineReportingConfigTypes | None = pydantic_Field( + default=None, discriminator="type" + ) + """The reporting configuration for the pipeline.""" + + storage: PipelineStorageConfigTypes | None = pydantic_Field( + default=None, discriminator="type" + ) + """The storage configuration for the pipeline.""" + + cache: PipelineCacheConfigTypes | None = pydantic_Field( + default=None, discriminator="type" + ) + """The cache configuration for the pipeline.""" + + root_dir: str | None = pydantic_Field( + description="The root directory for the pipeline. All other paths will be based on this root_dir.", + default=None, + ) + """The root directory for the pipeline.""" + + workflows: list[PipelineWorkflowReference] = pydantic_Field( + description="The workflows for the pipeline.", default_factory=list + ) + """The workflows for the pipeline.""" + + graphdb_params: GraphDBConfig|None = pydantic_Field( + description="Parameters for Graphdb collection", default=None + ) + """Parameters for Graphdb collection""" diff --git a/func-app/graphrag/index/config/reporting.py b/func-app/graphrag/index/config/reporting.py new file mode 100644 index 0000000000..921e24ae4e --- /dev/null +++ b/func-app/graphrag/index/config/reporting.py @@ -0,0 +1,77 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing 'PipelineReportingConfig', 'PipelineFileReportingConfig' and 'PipelineConsoleReportingConfig' models.""" + +from __future__ import annotations + +from typing import Generic, Literal, TypeVar + +from pydantic import BaseModel +from pydantic import Field as pydantic_Field + +from graphrag.config.enums import ReportingType + +T = TypeVar("T") + + +class PipelineReportingConfig(BaseModel, Generic[T]): + """Represent the reporting configuration for the pipeline.""" + + type: T + + +class PipelineFileReportingConfig(PipelineReportingConfig[Literal[ReportingType.file]]): + """Represent the file reporting configuration for the pipeline.""" + + type: Literal[ReportingType.file] = ReportingType.file + """The type of reporting.""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the reporting.", default=None + ) + """The base directory for the reporting.""" + + +class PipelineConsoleReportingConfig( + PipelineReportingConfig[Literal[ReportingType.console]] +): + """Represent the console reporting configuration for the pipeline.""" + + type: Literal[ReportingType.console] = ReportingType.console + """The type of reporting.""" + + +class PipelineBlobReportingConfig(PipelineReportingConfig[Literal[ReportingType.blob]]): + """Represents the blob reporting configuration for the pipeline.""" + + type: Literal[ReportingType.blob] = ReportingType.blob + """The type of reporting.""" + + connection_string: str | None = pydantic_Field( + description="The blob reporting connection string for the reporting.", + default=None, + ) + """The blob reporting connection string for the reporting.""" + + container_name: str = pydantic_Field( + description="The container name for reporting", default=None + ) + """The container name for reporting""" + + storage_account_blob_url: str | None = pydantic_Field( + description="The storage account blob url for reporting", default=None + ) + """The storage account blob url for reporting""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the reporting.", default=None + ) + """The base directory for the reporting.""" + + +PipelineReportingConfigTypes = ( + PipelineFileReportingConfig + | PipelineConsoleReportingConfig + | PipelineBlobReportingConfig +) diff --git a/func-app/graphrag/index/config/workflow.py b/func-app/graphrag/index/config/workflow.py new file mode 100644 index 0000000000..c26fef6ca0 --- /dev/null +++ b/func-app/graphrag/index/config/workflow.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing 'PipelineWorkflowReference' model.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel +from pydantic import Field as pydantic_Field + +PipelineWorkflowStep = dict[str, Any] +"""Represent a step in a workflow.""" + +PipelineWorkflowConfig = dict[str, Any] +"""Represent a configuration for a workflow.""" + + +class PipelineWorkflowReference(BaseModel): + """Represent a reference to a workflow, and can optionally be the workflow itself.""" + + name: str | None = pydantic_Field(description="Name of the workflow.", default=None) + """Name of the workflow.""" + + steps: list[PipelineWorkflowStep] | None = pydantic_Field( + description="The optional steps for the workflow.", default=None + ) + """The optional steps for the workflow.""" + + config: PipelineWorkflowConfig | None = pydantic_Field( + description="The optional configuration for the workflow.", default=None + ) + """The optional configuration for the workflow.""" diff --git a/func-app/graphrag/index/context.py b/func-app/graphrag/index/context.py new file mode 100644 index 0000000000..cdec0f6292 --- /dev/null +++ b/func-app/graphrag/index/context.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +# isort: skip_file +"""A module containing the 'PipelineRunStats' and 'PipelineRunContext' models.""" + +from dataclasses import dataclass as dc_dataclass +from dataclasses import field + +from .cache import PipelineCache +from graphrag.common.storage.typing import PipelineStorage + + +@dc_dataclass +class PipelineRunStats: + """Pipeline running stats.""" + + total_runtime: float = field(default=0) + """Float representing the total runtime.""" + + num_documents: int = field(default=0) + """Number of documents.""" + + input_load_time: float = field(default=0) + """Float representing the input load time.""" + + workflows: dict[str, dict[str, float]] = field(default_factory=dict) + """A dictionary of workflows.""" + + +@dc_dataclass +class PipelineRunContext: + """Provides the context for the current pipeline run.""" + + stats: PipelineRunStats + storage: PipelineStorage + cache: PipelineCache + + +# TODO: For now, just has the same props available to it +VerbRunContext = PipelineRunContext +"""Provides the context for the current verb run.""" diff --git a/func-app/graphrag/index/context_switch/contextSwitcher.py b/func-app/graphrag/index/context_switch/contextSwitcher.py new file mode 100644 index 0000000000..26c007e428 --- /dev/null +++ b/func-app/graphrag/index/context_switch/contextSwitcher.py @@ -0,0 +1,288 @@ +import asyncio +import os +from io import BytesIO +from pathlib import Path +from typing import cast + +import pandas as pd + +from common.graph_db_client import GraphDBClient +from graphrag.common.progress import ProgressReporter +from graphrag.common.storage import ( + BlobPipelineStorage, + FilePipelineStorage, + PipelineStorage, +) +from graphrag.common.utils.context_utils import get_files_by_contextid +from graphrag.config import ( + GraphRagConfig, + create_graphrag_config, +) +from graphrag.config.enums import StorageType +from graphrag.model.community_report import CommunityReport +from graphrag.model import TextUnit +from graphrag.model.entity import Entity +from graphrag.query.indexer_adapters import ( + read_indexer_entities, + read_indexer_reports, + read_indexer_text_units, +) +from graphrag.model.entity import Entity +from azure.cosmos import CosmosClient, 
PartitionKey
+from graphrag.vector_stores.base import BaseVectorStore
+from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType
+import logging
+
+class ContextSwitcher:
+    """ContextSwitcher class definition."""
+
+    def __init__(self, root_dir: str, config_dir: str, reporter: ProgressReporter,
+                 context_id: str, community_level: int,
+                 data_dir: str | None = None,
+                 optimized_search: bool = False,
+                 use_kusto_community_reports: bool = False):
+        self.root_dir = root_dir
+        self.config_dir = config_dir
+        self.data_dir = data_dir
+        self.reporter = reporter
+        self.context_id = context_id
+        self.optimized_search = optimized_search
+        self.community_level = community_level
+        self.use_kusto_community_reports = use_kusto_community_reports
+        logging.info("ContextSwitcher initialized")
+
+    def get_embedding_store(self, config_args):
+        """Set up the vector store and return it."""
+        if not config_args:
+            config_args = {}
+
+        collection_name = config_args.get(
+            "query_collection_name", "entity_description_embeddings"
+        )
+        collection_name += "_" + self.context_id
+        config_args.update({"collection_name": collection_name})
+
+        vector_name = config_args.get(
+            "vector_search_column", "description_embedding"
+        )
+        config_args.update({"vector_name": vector_name})
+        config_args.update({"reports_name": f"reports_{self.context_id}"})
+        config_args.update({"text_units_name": f"text_units_{self.context_id}"})
+
+        return VectorStoreFactory.get_vector_store(
+            vector_store_type=VectorStoreType.Kusto, kwargs=config_args
+        )
+
+    def setup_vector_store(
+        self, config_args: dict | None = None
+    ) -> BaseVectorStore:
+        """Connect to the vector store and create this context's collections."""
+        # connect(**config_args) would fail on None, so normalize first.
+        config_args = config_args or {}
+        description_embedding_store = self.get_embedding_store(config_args)
+        description_embedding_store.connect(**config_args)
+
+        description_embedding_store.setup_entities()
+        if self.use_kusto_community_reports:
+            description_embedding_store.setup_reports()
+
+        description_embedding_store.setup_text_units()
+
+        return description_embedding_store
+
+    def _read_config_parameters(self, root: str, config: str | None):
+        reporter = self.reporter
+        _root = Path(root)
+        settings_yaml = (
+            Path(config)
+            if config and Path(config).suffix in [".yaml", ".yml"]
+            else _root / "settings.yaml"
+        )
+        if not settings_yaml.exists():
+            settings_yaml = _root / "settings.yml"
+
+        if settings_yaml.exists():
+            reporter.info(f"Reading settings from {settings_yaml}")
+            with settings_yaml.open(
+                "rb",
+            ) as file:
+                import yaml
+
+                data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
+                return create_graphrag_config(data, root)
+
+        settings_json = (
+            Path(config)
+            if config and Path(config).suffix == ".json"
+            else _root / "settings.json"
+        )
+        if settings_json.exists():
+            reporter.info(f"Reading settings from {settings_json}")
+            with settings_json.open("rb") as file:
+                import json
+
+                data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
+                return create_graphrag_config(data, root)
+
+        reporter.info("Reading settings from environment variables")
+        return create_graphrag_config(root_dir=root)
+
+    def activate(self):
+        """Activate the context."""
+        #1. read the context id to fileId mapping.
+        #2. read the file from storage using common/blob_storage_client.py
+        #3. GraphDB: use cosmos db client to load data into Cosmos DB.
+        #4. KustoDB: use Kusto client to load embedding data into Kusto.
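+        #
+        # A minimal usage sketch (illustrative only; the root path, context id,
+        # and the assumption that settings.yaml points at a Kusto vector store
+        # are made-up examples, not values defined in this module):
+        #
+        #   switcher = ContextSwitcher(
+        #       root_dir="./ragtest", config_dir=None, reporter=reporter,
+        #       context_id="ctx-123", community_level=2)
+        #   switcher.activate()      # load this context's artifacts into Kusto/Cosmos
+        #   switcher.deactivate()    # unload them when switching away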
+        data_dir = self.data_dir
+        root_dir = self.root_dir
+        config_dir = self.config_dir
+        reporter = self.reporter
+        context_id = self.context_id
+        optimized_search = self.optimized_search
+        community_level = self.community_level
+
+        def read_parquet_file(storage: PipelineStorage, path: str):
+            # TODO: create a different enum for parquet storage type
+            file_data = asyncio.run(storage.get(path, True))
+            if file_data is None:
+                return pd.DataFrame()
+            return pd.read_parquet(BytesIO(file_data), engine="pyarrow")
+
+        def _configure_paths_and_settings(
+            data_dir: str | None,
+            root_dir: str | None,
+            config_dir: str | None,
+        ) -> tuple[str, str | None, GraphRagConfig]:
+            if data_dir is None and root_dir is None:
+                msg = "Either data_dir or root_dir must be provided."
+                raise ValueError(msg)
+            if data_dir is None:
+                data_dir = _infer_data_dir(cast(str, root_dir))
+            config = _create_graphrag_config(root_dir, config_dir)
+            return data_dir, root_dir, config
+
+        def _infer_data_dir(root: str) -> str:
+            output = Path(root) / "output"
+            # use the latest data-run folder
+            if output.exists():
+                folders = sorted(output.iterdir(), key=os.path.getmtime, reverse=True)
+                if len(folders) > 0:
+                    folder = folders[0]
+                    return str((folder / "artifacts").absolute())
+            msg = f"Could not infer data directory from root={root}"
+            raise ValueError(msg)
+
+        def _create_graphrag_config(
+            root: str | None,
+            config_dir: str | None,
+        ) -> GraphRagConfig:
+            """Create a GraphRag configuration."""
+            return self._read_config_parameters(root or "./", config_dir)
+
+        ################################################################################
+
+        _, _, config = _configure_paths_and_settings(
+            data_dir, root_dir, config_dir
+        )
+
+        if config.storage.type == StorageType.memory:
+            raise ValueError("Memory storage is not supported")
+        if config.storage.type == StorageType.blob:
+            if config.storage.container_name is not None:
+                input_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string, container_name=config.storage.container_name, storage_account_blob_url=config.storage.storage_account_blob_url)
+            else:
+                raise ValueError("Storage type is Blob but container name is invalid")
+        if config.storage.type == StorageType.file:
+            input_storage_client = FilePipelineStorage(config.root_dir)
+
+        data_paths = get_files_by_contextid(config, context_id)
+        final_nodes = pd.DataFrame()
+        final_community_reports = pd.DataFrame()
+        final_text_units = pd.DataFrame()
+        final_relationships = pd.DataFrame()
+        final_entities = pd.DataFrame()
+        final_covariates = pd.DataFrame()
+        graph_db_client = None
+
+        if config.graphdb.enabled:
+            cosmos_client = CosmosClient(
+                f"{config.graphdb.cosmos_url}",
+                f"{config.graphdb.account_key}",
+            )
+            database_name = config.graphdb.username.split("/")[2]
+            database = cosmos_client.get_database_client(database_name)
+            graph_name = config.graphdb.username.split("/")[-1] + "-contextid-" + context_id
+            graph = database.create_container_if_not_exists(
+                id=graph_name,
+                partition_key=PartitionKey(path='/category'),
+                offer_throughput=400
+            )
+            graph_db_client = GraphDBClient(config.graphdb, context_id)
+
+        description_embedding_store = self.setup_vector_store(config_args=config.embeddings.vector_store)
+
+        for data_path in data_paths:
+            # check from the config for the output storage type and then read the data from the storage.
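+            #
+            # Each data_path is expected to contain the standard indexer
+            # artifacts (an assumed layout, matching the reads below):
+            #   <data_path>/create_final_nodes.parquet
+            #   <data_path>/create_final_community_reports.parquet
+            #   <data_path>/create_final_text_units.parquet
+            #   <data_path>/create_final_covariates.parquet   (skipped when optimized_search is set)
+            #   <data_path>/create_final_relationships.parquet
+            #   <data_path>/create_final_entities.parquet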
+
+            # GraphDB: we may need to change the reads below to pull nodes data from Graph DB
+            final_nodes = read_parquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")
+            final_community_reports = read_parquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet") # KustoDB: final_entities, final_nodes and final_community_reports should be merged and inserted into Kusto
+            final_text_units = read_parquet_file(input_storage_client, data_path + "/create_final_text_units.parquet") # LanceDB search needs it for embedding mapping; we have embeddings in entities and should use them from there. KustoDB already must have sorted it.
+
+            if not optimized_search:
+                final_covariates = read_parquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")
+
+            final_relationships = read_parquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")
+            final_entities = read_parquet_file(input_storage_client, data_path + "/create_final_entities.parquet")
+
+        vector_store_args = (
+            config.embeddings.vector_store if config.embeddings.vector_store else {}
+        )
+
+        reporter.info(f"Vector Store Args: {vector_store_args}")
+
+        if "type" not in vector_store_args:
+            raise ValueError("vector_store.type can't be empty")
+
+        vector_store_type = vector_store_args.get("type")
+
+        if vector_store_type != VectorStoreType.Kusto:
+            raise ValueError("Context switching is only supported for vector_store.type=kusto")
+
+        entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read final nodes data and entities data and merge them.
+        reports = read_indexer_reports(final_community_reports, final_nodes, community_level)
+        text_units = read_indexer_text_units(final_text_units)
+
+        description_embedding_store.load_entities(entities)
+        if self.use_kusto_community_reports:
+            raise ValueError("Community reports not supported for kusto.")
+            #description_embedding_store.load_reports(reports)
+
+        description_embedding_store.load_text_units(text_units)
+
+        if config.graphdb.enabled:
+            graph_db_client.write_vertices(final_entities)
+            graph_db_client.write_edges(final_relationships)
+            graph_db_client._client.close()
+
+    def deactivate(self):
+        """Deactivate the context."""
+        config = self._read_config_parameters(self.root_dir or "./", self.config_dir)
+        config_args = config.embeddings.vector_store or {}
+        description_embedding_store = self.get_embedding_store(config_args)
+        description_embedding_store.connect(**config_args)
+        description_embedding_store.unload_entities()
+
+        if config.graphdb.enabled:
+            g_client = GraphDBClient(config.graphdb, self.context_id)
+            g_client.remove_graph()
\ No newline at end of file
diff --git a/func-app/graphrag/index/create_pipeline_config.py b/func-app/graphrag/index/create_pipeline_config.py
new file mode 100644
index 0000000000..7cf91ec308
--- /dev/null
+++ b/func-app/graphrag/index/create_pipeline_config.py
@@ -0,0 +1,595 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Default configuration methods definition.""" + +import json +import logging +from pathlib import Path +from .emit.types import TableEmitterType + +from graphrag.config.enums import ( + CacheType, + InputFileType, + ReportingType, + StorageType, + TextEmbeddingTarget, +) +from graphrag.config.models import ( + GraphRagConfig, + TextEmbeddingConfig, +) +from graphrag.index.config.cache import ( + PipelineBlobCacheConfig, + PipelineCacheConfigTypes, + PipelineFileCacheConfig, + PipelineMemoryCacheConfig, + PipelineNoneCacheConfig, +) +from graphrag.index.config.input import ( + PipelineCSVInputConfig, + PipelineInputConfigTypes, + PipelineTextInputConfig, +) +from graphrag.index.config.pipeline import ( + PipelineConfig, +) +from graphrag.index.config.reporting import ( + PipelineBlobReportingConfig, + PipelineConsoleReportingConfig, + PipelineFileReportingConfig, + PipelineReportingConfigTypes, +) +from graphrag.common.config.storage import ( + PipelineBlobStorageConfig, + PipelineFileStorageConfig, + PipelineMemoryStorageConfig, + PipelineStorageConfigTypes, +) +from graphrag.index.config.workflow import ( + PipelineWorkflowReference, +) +from graphrag.index.workflows.default_workflows import ( + create_base_documents, + create_base_entity_graph, + create_base_extracted_entities, + create_base_text_units, + create_final_communities, + create_final_community_reports, + create_final_covariates, + create_final_documents, + create_final_entities, + create_final_nodes, + create_final_relationships, + create_final_text_units, + create_summarized_entities, + join_text_units_to_covariate_ids, + join_text_units_to_entity_ids, + join_text_units_to_relationship_ids, +) + +log = logging.getLogger(__name__) + + +entity_name_embedding = "entity.name" +entity_description_embedding = "entity.description" +relationship_description_embedding = "relationship.description" +document_raw_content_embedding = "document.raw_content" +community_title_embedding = "community.title" +community_summary_embedding = "community.summary" +community_full_content_embedding = "community.full_content" +text_unit_text_embedding = "text_unit.text" + +all_embeddings: set[str] = { + entity_name_embedding, + entity_description_embedding, + relationship_description_embedding, + document_raw_content_embedding, + community_title_embedding, + community_summary_embedding, + community_full_content_embedding, + text_unit_text_embedding, +} +required_embeddings: set[str] = {entity_description_embedding} + + +builtin_document_attributes: set[str] = { + "id", + "source", + "text", + "title", + "timestamp", + "year", + "month", + "day", + "hour", + "minute", + "second", +} + + +def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineConfig: + """Get the default config for the pipeline.""" + # relative to the root_dir + if verbose: + _log_llm_settings(settings) + + skip_workflows = _determine_skip_workflows(settings) + embedded_fields = _get_embedded_fields(settings) + covariates_enabled = ( + settings.claim_extraction.enabled + and create_final_covariates not in skip_workflows + ) + + result = PipelineConfig( + root_dir=settings.root_dir, + input=_get_pipeline_input_config(settings), + reporting=_get_reporting_config(settings), + storage=_get_storage_config(settings), + cache=_get_cache_config(settings), + workflows=[ + *_document_workflows(settings, embedded_fields), + *_text_unit_workflows(settings, covariates_enabled, embedded_fields), + *_graph_workflows(settings, 
embedded_fields), + *_community_workflows(settings, covariates_enabled, embedded_fields), + *(_covariate_workflows(settings) if covariates_enabled else []), + ], + graphdb_params=settings.graphdb + ) + + # Remove any workflows that were specified to be skipped + log.info("skipping workflows %s", ",".join(skip_workflows)) + result.workflows = [w for w in result.workflows if w.name not in skip_workflows] + return result + + +def _get_embedded_fields(settings: GraphRagConfig) -> set[str]: + match settings.embeddings.target: + case TextEmbeddingTarget.all: + return all_embeddings - {*settings.embeddings.skip} + case TextEmbeddingTarget.required: + return required_embeddings + case _: + msg = f"Unknown embeddings target: {settings.embeddings.target}" + raise ValueError(msg) + + +def _determine_skip_workflows(settings: GraphRagConfig) -> list[str]: + skip_workflows = settings.skip_workflows + if ( + create_final_covariates in skip_workflows + and join_text_units_to_covariate_ids not in skip_workflows + ): + skip_workflows.append(join_text_units_to_covariate_ids) + return skip_workflows + + +def _log_llm_settings(settings: GraphRagConfig) -> None: + log.info( + "Using LLM Config %s", + json.dumps( + {**settings.entity_extraction.llm.model_dump(), "api_key": "*****"}, + indent=4, + ), + ) + log.info( + "Using Embeddings Config %s", + json.dumps( + {**settings.embeddings.llm.model_dump(), "api_key": "*****"}, indent=4 + ), + ) + + +def _document_workflows( + settings: GraphRagConfig, embedded_fields: set[str] +) -> list[PipelineWorkflowReference]: + skip_document_raw_content_embedding = ( + document_raw_content_embedding not in embedded_fields + ) + return [ + PipelineWorkflowReference( + name=create_base_documents, + config={ + "document_attribute_columns": list( + {*(settings.input.document_attribute_columns)} + - builtin_document_attributes + ) + }, + ), + PipelineWorkflowReference( + name=create_final_documents, + config={ + "document_raw_content_embed": _get_embedding_settings( + settings.embeddings, + "document_raw_content", + { + "title_column": "raw_content", + "collection_name": "final_documents_raw_content_embedding", + }, + ), + "skip_raw_content_embedding": skip_document_raw_content_embedding, + }, + ), + ] + + +def _text_unit_workflows( + settings: GraphRagConfig, + covariates_enabled: bool, + embedded_fields: set[str], +) -> list[PipelineWorkflowReference]: + skip_text_unit_embedding = text_unit_text_embedding not in embedded_fields + return [ + PipelineWorkflowReference( + name=create_base_text_units, + config={ + "chunk_by": settings.chunks.group_by_columns, + "text_chunk": { + "strategy": settings.chunks.resolved_strategy( + settings.encoding_model + ) + }, + }, + ), + PipelineWorkflowReference( + name=join_text_units_to_entity_ids, + ), + PipelineWorkflowReference( + name=join_text_units_to_relationship_ids, + ), + *( + [ + PipelineWorkflowReference( + name=join_text_units_to_covariate_ids, + ) + ] + if covariates_enabled + else [] + ), + PipelineWorkflowReference( + name=create_final_text_units, + config={ + "text_unit_text_embed": _get_embedding_settings( + settings.embeddings, + "text_unit_text", + {"title_column": "text", "collection_name": "text_units_embedding"}, + ), + "covariates_enabled": covariates_enabled, + "skip_text_unit_embedding": skip_text_unit_embedding, + }, + ), + ] + + +def _get_embedding_settings( + settings: TextEmbeddingConfig, + embedding_name: str, + vector_store_params: dict | None = None, +) -> dict: + vector_store_settings = settings.vector_store + if 
vector_store_settings is None: + return {"strategy": settings.resolved_strategy()} + # + # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding. + # settings.vector_store.base contains connection information, or may be undefined + # settings.vector_store.<vector_name> contains the specific settings for this embedding + # + strategy = settings.resolved_strategy() # get the default strategy + strategy.update({ + "vector_store": {**vector_store_settings, **(vector_store_params or {})} + }) # update the default strategy with the vector store settings + # This ensures the vector store config is part of the strategy and not the global config + return { + "strategy": strategy, + "embedding_name": embedding_name, + } + + +def _graph_workflows( + settings: GraphRagConfig, embedded_fields: set[str] +) -> list[PipelineWorkflowReference]: + skip_entity_name_embedding = entity_name_embedding not in embedded_fields + skip_entity_description_embedding = ( + entity_description_embedding not in embedded_fields + ) + skip_relationship_description_embedding = ( + relationship_description_embedding not in embedded_fields + ) + return [ + PipelineWorkflowReference( + name=create_base_extracted_entities, + config={ + "graphml_snapshot": settings.snapshots.graphml, + "raw_entity_snapshot": settings.snapshots.raw_entities, + "entity_extract": { + **settings.entity_extraction.parallelization.model_dump(), + "async_mode": settings.entity_extraction.async_mode, + "strategy": settings.entity_extraction.resolved_strategy( + settings.root_dir, settings.encoding_model + ), + "entity_types": settings.entity_extraction.entity_types, + }, + }, + ), + PipelineWorkflowReference( + name=create_summarized_entities, + config={ + "graphml_snapshot": settings.snapshots.graphml, + "summarize_descriptions": { + **settings.summarize_descriptions.parallelization.model_dump(), + "async_mode": settings.summarize_descriptions.async_mode, + "strategy": settings.summarize_descriptions.resolved_strategy( + settings.root_dir, + ), + }, + }, + ), + PipelineWorkflowReference( + name=create_base_entity_graph, + config={ + "graphml_snapshot": settings.snapshots.graphml, + "embed_graph_enabled": settings.embed_graph.enabled, + "cluster_graph": { + "strategy": settings.cluster_graph.resolved_strategy() + }, + "embed_graph": {"strategy": settings.embed_graph.resolved_strategy()}, + }, + ), + PipelineWorkflowReference( + name=create_final_entities, + config={ + "entity_name_embed": _get_embedding_settings( + settings.embeddings, + "entity_name", + { + "title_column": "name", + "collection_name": "entity_name_embeddings", + }, + ), + "entity_name_description_embed": _get_embedding_settings( + settings.embeddings, + "entity_name_description", + { + "title_column": "description", + "collection_name": "entity_description_embeddings", + "vector_name": "vector", + "reports_name": "reports", + }, + ), + "skip_name_embedding": skip_entity_name_embedding, + "skip_description_embedding": skip_entity_description_embedding, + "emitter_type": TableEmitterType.Graphdb, + }, + ), + PipelineWorkflowReference( + name=create_final_relationships, + config={ + "relationship_description_embed": _get_embedding_settings( + settings.embeddings, + "relationship_description", + { + "title_column": "description", + "collection_name": "relationships_description_embeddings", + }, + ), + "skip_description_embedding": skip_relationship_description_embedding, + "emitter_type": TableEmitterType.Graphdb, + }, + ), +
PipelineWorkflowReference( + name=create_final_nodes, + config={ + "layout_graph_enabled": settings.umap.enabled, + "snapshot_top_level_nodes": settings.snapshots.top_level_nodes, + }, + ), + ] + + +def _community_workflows( + settings: GraphRagConfig, covariates_enabled: bool, embedded_fields: set[str] +) -> list[PipelineWorkflowReference]: + skip_community_title_embedding = community_title_embedding not in embedded_fields + skip_community_summary_embedding = ( + community_summary_embedding not in embedded_fields + ) + skip_community_full_content_embedding = ( + community_full_content_embedding not in embedded_fields + ) + return [ + PipelineWorkflowReference(name=create_final_communities), + PipelineWorkflowReference( + name=create_final_community_reports, + config={ + "covariates_enabled": covariates_enabled, + "skip_title_embedding": skip_community_title_embedding, + "skip_summary_embedding": skip_community_summary_embedding, + "skip_full_content_embedding": skip_community_full_content_embedding, + "create_community_reports": { + **settings.community_reports.parallelization.model_dump(), + "async_mode": settings.community_reports.async_mode, + "strategy": settings.community_reports.resolved_strategy( + settings.root_dir + ), + }, + "community_report_full_content_embed": _get_embedding_settings( + settings.embeddings, + "community_report_full_content", + { + "title_column": "full_content", + "collection_name": "final_community_reports_full_content_embedding", + }, + ), + "community_report_summary_embed": _get_embedding_settings( + settings.embeddings, + "community_report_summary", + { + "title_column": "summary", + "collection_name": "final_community_reports_summary_embedding", + }, + ), + "community_report_title_embed": _get_embedding_settings( + settings.embeddings, + "community_report_title", + {"title_column": "title"}, + ), + }, + ), + ] + + +def _covariate_workflows( + settings: GraphRagConfig, +) -> list[PipelineWorkflowReference]: + return [ + PipelineWorkflowReference( + name=create_final_covariates, + config={ + "claim_extract": { + **settings.claim_extraction.parallelization.model_dump(), + "strategy": settings.claim_extraction.resolved_strategy( + settings.root_dir, settings.encoding_model + ), + }, + }, + ) + ] + + +def _get_pipeline_input_config( + settings: GraphRagConfig, +) -> PipelineInputConfigTypes: + file_type = settings.input.file_type + match file_type: + case InputFileType.csv: + return PipelineCSVInputConfig( + base_dir=settings.input.base_dir, + file_pattern=settings.input.file_pattern, + encoding=settings.input.encoding, + source_column=settings.input.source_column, + timestamp_column=settings.input.timestamp_column, + timestamp_format=settings.input.timestamp_format, + text_column=settings.input.text_column, + title_column=settings.input.title_column, + type=settings.input.type, + connection_string=settings.input.connection_string, + storage_account_blob_url=settings.input.storage_account_blob_url, + container_name=settings.input.container_name, + ) + case InputFileType.text: + return PipelineTextInputConfig( + base_dir=settings.input.base_dir, + file_pattern=settings.input.file_pattern, + encoding=settings.input.encoding, + type=settings.input.type, + connection_string=settings.input.connection_string, + storage_account_blob_url=settings.input.storage_account_blob_url, + container_name=settings.input.container_name, + ) + case _: + msg = f"Unknown input type: {file_type}" + raise ValueError(msg) + + +def _get_reporting_config( + settings: GraphRagConfig, +) 
-> PipelineReportingConfigTypes: + """Get the reporting config from the settings.""" + match settings.reporting.type: + case ReportingType.file: + # relative to the root_dir + return PipelineFileReportingConfig(base_dir=settings.reporting.base_dir) + case ReportingType.blob: + connection_string = settings.reporting.connection_string + storage_account_blob_url = settings.reporting.storage_account_blob_url + container_name = settings.reporting.container_name + if container_name is None: + msg = "Container name must be provided for blob reporting." + raise ValueError(msg) + if connection_string is None and storage_account_blob_url is None: + msg = "Connection string or storage account blob url must be provided for blob reporting." + raise ValueError(msg) + return PipelineBlobReportingConfig( + connection_string=connection_string, + container_name=container_name, + base_dir=settings.reporting.base_dir, + storage_account_blob_url=storage_account_blob_url, + ) + case ReportingType.console: + return PipelineConsoleReportingConfig() + case _: + # relative to the root_dir + return PipelineFileReportingConfig(base_dir=settings.reporting.base_dir) + + +def _get_storage_config( + settings: GraphRagConfig, +) -> PipelineStorageConfigTypes: + """Get the storage type from the settings.""" + root_dir = settings.root_dir + match settings.storage.type: + case StorageType.memory: + return PipelineMemoryStorageConfig() + case StorageType.file: + # relative to the root_dir + base_dir = settings.storage.base_dir + if base_dir is None: + msg = "Base directory must be provided for file storage." + raise ValueError(msg) + return PipelineFileStorageConfig(base_dir=str(Path(root_dir) / base_dir)) + case StorageType.blob: + connection_string = settings.storage.connection_string + storage_account_blob_url = settings.storage.storage_account_blob_url + container_name = settings.storage.container_name + if container_name is None: + msg = "Container name must be provided for blob storage." + raise ValueError(msg) + if connection_string is None and storage_account_blob_url is None: + msg = "Connection string or storage account blob url must be provided for blob storage." + raise ValueError(msg) + return PipelineBlobStorageConfig( + connection_string=connection_string, + container_name=container_name, + base_dir=settings.storage.base_dir, + storage_account_blob_url=storage_account_blob_url, + ) + case _: + # relative to the root_dir + base_dir = settings.storage.base_dir + if base_dir is None: + msg = "Base directory must be provided for file storage." + raise ValueError(msg) + return PipelineFileStorageConfig(base_dir=str(Path(root_dir) / base_dir)) + + +def _get_cache_config( + settings: GraphRagConfig, +) -> PipelineCacheConfigTypes: + """Get the cache type from the settings.""" + match settings.cache.type: + case CacheType.memory: + return PipelineMemoryCacheConfig() + case CacheType.file: + # relative to root dir + return PipelineFileCacheConfig(base_dir=settings.cache.base_dir) + case CacheType.none: + return PipelineNoneCacheConfig() + case CacheType.blob: + connection_string = settings.cache.connection_string + storage_account_blob_url = settings.cache.storage_account_blob_url + container_name = settings.cache.container_name + if container_name is None: + msg = "Container name must be provided for blob cache." + raise ValueError(msg) + if connection_string is None and storage_account_blob_url is None: + msg = "Connection string or storage account blob url must be provided for blob cache." 
+ raise ValueError(msg) + return PipelineBlobCacheConfig( + connection_string=connection_string, + container_name=container_name, + base_dir=settings.cache.base_dir, + storage_account_blob_url=storage_account_blob_url, + ) + case _: + # relative to root dir + return PipelineFileCacheConfig(base_dir="./cache") diff --git a/func-app/graphrag/index/emit/__init__.py b/func-app/graphrag/index/emit/__init__.py new file mode 100644 index 0000000000..354989e338 --- /dev/null +++ b/func-app/graphrag/index/emit/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Definitions for emitting pipeline artifacts to storage.""" + +from .csv_table_emitter import CSVTableEmitter +from .factories import create_table_emitter, create_table_emitters +from .json_table_emitter import JsonTableEmitter +from .parquet_table_emitter import ParquetTableEmitter +from .table_emitter import TableEmitter +from .types import TableEmitterType + +__all__ = [ + "CSVTableEmitter", + "JsonTableEmitter", + "ParquetTableEmitter", + "TableEmitter", + "TableEmitterType", + "create_table_emitter", + "create_table_emitters", +] diff --git a/func-app/graphrag/index/emit/csv_table_emitter.py b/func-app/graphrag/index/emit/csv_table_emitter.py new file mode 100644 index 0000000000..0c208d1264 --- /dev/null +++ b/func-app/graphrag/index/emit/csv_table_emitter.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""CSVTableEmitter module.""" + +import logging + +import pandas as pd + +from graphrag.common.storage import PipelineStorage + +from .table_emitter import TableEmitter + +log = logging.getLogger(__name__) + + +class CSVTableEmitter(TableEmitter): + """CSVTableEmitter class.""" + + _storage: PipelineStorage + + def __init__(self, storage: PipelineStorage): + """Create a new CSV Table Emitter.""" + self._storage = storage + + async def emit(self, name: str, data: pd.DataFrame) -> None: + """Emit a dataframe to storage.""" + filename = f"{name}.csv" + log.info("emitting CSV table %s", filename) + await self._storage.set( + filename, + data.to_csv(), + ) diff --git a/func-app/graphrag/index/emit/factories.py b/func-app/graphrag/index/emit/factories.py new file mode 100644 index 0000000000..1c4e218785 --- /dev/null +++ b/func-app/graphrag/index/emit/factories.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024 Microsoft Corporation. 
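All the emitters share the same async emit(name, frame) surface, so they can be driven by hand. A minimal sketch for CSVTableEmitter, using a hypothetical dict-backed stand-in for the PipelineStorage it writes through (only the set method the emitter actually calls is implemented; the import assumes the func-app package is on the path):

import asyncio

import pandas as pd

from graphrag.index.emit import CSVTableEmitter


class DictStorage:
    """Hypothetical PipelineStorage stand-in: captures writes in a dict."""

    def __init__(self) -> None:
        self.files: dict[str, str] = {}

    async def set(self, key: str, value, encoding: str | None = None) -> None:
        self.files[key] = value


async def main() -> None:
    storage = DictStorage()
    emitter = CSVTableEmitter(storage)  # type: ignore[arg-type]
    await emitter.emit("create_final_entities", pd.DataFrame({"name": ["A", "B"]}))
    print(storage.files["create_final_entities.csv"])


asyncio.run(main())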
+# Licensed under the MIT License + +"""Table Emitter Factories.""" + +from graphrag.config.models.graphdb_config import GraphDBConfig +from graphrag.common.storage import PipelineStorage +from graphrag.index.typing import ErrorHandlerFn + +from .csv_table_emitter import CSVTableEmitter +from .json_table_emitter import JsonTableEmitter +from .parquet_table_emitter import ParquetTableEmitter +from .graph_db_emitter import GraphDBEmitter +from .table_emitter import TableEmitter +from .types import TableEmitterType + +def create_table_emitter( + emitter_type: TableEmitterType, storage: PipelineStorage, on_error: ErrorHandlerFn, graphdb_params: GraphDBConfig|None = None, context_id: str|None = None +) -> TableEmitter: + """Create a table emitter based on the specified type.""" + match emitter_type: + case TableEmitterType.Json: + return JsonTableEmitter(storage) + case TableEmitterType.Parquet: + return ParquetTableEmitter(storage, on_error) + case TableEmitterType.CSV: + return CSVTableEmitter(storage) + case TableEmitterType.Graphdb: + return GraphDBEmitter(graphdb_params,context_id) + case _: + msg = f"Unsupported table emitter type: {emitter_type}" + raise ValueError(msg) + + +def create_table_emitters( + emitter_types: list[TableEmitterType], + storage: PipelineStorage, + on_error: ErrorHandlerFn, + graphdb_params: GraphDBConfig|None = None, + context_id: str|None = None, +) -> list[TableEmitter]: + """Create a list of table emitters based on the specified types.""" + return [ + create_table_emitter(emitter_type, storage, on_error, graphdb_params,context_id) + for emitter_type in emitter_types + ] diff --git a/func-app/graphrag/index/emit/graph_db_emitter.py b/func-app/graphrag/index/emit/graph_db_emitter.py new file mode 100644 index 0000000000..d8018ee678 --- /dev/null +++ b/func-app/graphrag/index/emit/graph_db_emitter.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""GraphDBEmitter module.""" + +import pandas as pd +from common.graph_db_client import GraphDBClient +from .table_emitter import TableEmitter +from graphrag.config.models.graphdb_config import GraphDBConfig + +class GraphDBEmitter(TableEmitter): + """Graph DB Emitter.""" + + def __init__(self, graph_db_params: GraphDBConfig|None,context_id: str|None): + self.graph_db_client = GraphDBClient(graph_db_params,context_id) + self.allowed_workflows = ['create_final_entities','create_final_relationships'] + + async def emit(self, name: str, data: pd.DataFrame) -> None: + if name not in self.allowed_workflows: + return + if name == 'create_final_entities': + self.graph_db_client.write_vertices(data) + if name == 'create_final_relationships': + self.graph_db_client.write_edges(data) \ No newline at end of file diff --git a/func-app/graphrag/index/emit/json_table_emitter.py b/func-app/graphrag/index/emit/json_table_emitter.py new file mode 100644 index 0000000000..39f936b781 --- /dev/null +++ b/func-app/graphrag/index/emit/json_table_emitter.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Microsoft Corporation. 
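create_table_emitters is how the pipeline fans a single workflow result out to several sinks at once. A hypothetical wiring sketch; storage (a PipelineStorage), settings (a GraphRagConfig), and the emitted frame are assumed to already exist in scope:

from graphrag.index.emit import TableEmitterType, create_table_emitters


def on_error(exc: BaseException | None, stack: str | None, details: dict | None) -> None:
    # ErrorHandlerFn-compatible callback; shape inferred from the factory above.
    print("emitter error:", exc)


# Fan one workflow result out to a parquet file and the graph database.
emitters = create_table_emitters(
    [TableEmitterType.Parquet, TableEmitterType.Graphdb],
    storage,        # assumed: any PipelineStorage implementation
    on_error,
    graphdb_params=settings.graphdb,  # assumed: GraphDBConfig from settings.yaml
    context_id=None,
)


async def emit_all(name, frame):
    # The GraphDB emitter silently skips tables outside its allow-list
    # (see graph_db_emitter.py above), so offering every table is safe.
    for emitter in emitters:
        await emitter.emit(name, frame)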
+# Licensed under the MIT License + +"""JsonTableEmitter module.""" + +import logging + +import pandas as pd + +from graphrag.common.storage import PipelineStorage + +from .table_emitter import TableEmitter + +log = logging.getLogger(__name__) + + +class JsonTableEmitter(TableEmitter): + """JsonTableEmitter class.""" + + _storage: PipelineStorage + + def __init__(self, storage: PipelineStorage): + """Create a new Json Table Emitter.""" + self._storage = storage + + async def emit(self, name: str, data: pd.DataFrame) -> None: + """Emit a dataframe to storage.""" + filename = f"{name}.json" + + log.info("emitting JSON table %s", filename) + await self._storage.set( + filename, + data.to_json(orient="records", lines=True, force_ascii=False), + ) diff --git a/func-app/graphrag/index/emit/parquet_table_emitter.py b/func-app/graphrag/index/emit/parquet_table_emitter.py new file mode 100644 index 0000000000..aa6dd38f96 --- /dev/null +++ b/func-app/graphrag/index/emit/parquet_table_emitter.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""ParquetTableEmitter module.""" + +import logging +import traceback + +import pandas as pd +from pyarrow.lib import ArrowInvalid, ArrowTypeError + +from graphrag.common.storage import PipelineStorage +from graphrag.index.typing import ErrorHandlerFn + +from .table_emitter import TableEmitter + +log = logging.getLogger(__name__) + + +class ParquetTableEmitter(TableEmitter): + """ParquetTableEmitter class.""" + + _storage: PipelineStorage + _on_error: ErrorHandlerFn + + def __init__( + self, + storage: PipelineStorage, + on_error: ErrorHandlerFn, + ): + """Create a new Parquet Table Emitter.""" + self._storage = storage + self._on_error = on_error + + async def emit(self, name: str, data: pd.DataFrame) -> None: + """Emit a dataframe to storage.""" + filename = f"{name}.parquet" + log.info("emitting parquet table %s", filename) + try: + await self._storage.set(filename, data.to_parquet()) + except ArrowTypeError as e: + log.exception("Error while emitting parquet table") + self._on_error( + e, + traceback.format_exc(), + None, + ) + except ArrowInvalid as e: + log.exception("Error while emitting parquet table") + self._on_error( + e, + traceback.format_exc(), + None, + ) diff --git a/func-app/graphrag/index/emit/table_emitter.py b/func-app/graphrag/index/emit/table_emitter.py new file mode 100644 index 0000000000..2161eeb523 --- /dev/null +++ b/func-app/graphrag/index/emit/table_emitter.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""TableEmitter protocol for emitting tables to a destination.""" + +from typing import Protocol + +import pandas as pd + + +class TableEmitter(Protocol): + """TableEmitter protocol for emitting tables to a destination.""" + + async def emit(self, name: str, data: pd.DataFrame) -> None: + """Emit a dataframe to storage.""" diff --git a/func-app/graphrag/index/emit/types.py b/func-app/graphrag/index/emit/types.py new file mode 100644 index 0000000000..0b0ff88541 --- /dev/null +++ b/func-app/graphrag/index/emit/types.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. 
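Because TableEmitter is a typing.Protocol, a new sink needs no inheritance, only a structurally matching async emit. A hypothetical in-memory emitter that would type-check anywhere a TableEmitter is expected, e.g. in tests:

import pandas as pd


class InMemoryEmitter:
    """Hypothetical TableEmitter: keeps emitted frames for inspection."""

    def __init__(self) -> None:
        self.tables: dict[str, pd.DataFrame] = {}

    async def emit(self, name: str, data: pd.DataFrame) -> None:
        self.tables[name] = data.copy()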
+# Licensed under the MIT License + +"""Table Emitter Types.""" + +from enum import Enum + + +class TableEmitterType(str, Enum): + """Table Emitter Types.""" + + Json = "json" + Parquet = "parquet" + CSV = "csv" + Graphdb = "graphdb" diff --git a/func-app/graphrag/index/errors.py b/func-app/graphrag/index/errors.py new file mode 100644 index 0000000000..430cf27d0f --- /dev/null +++ b/func-app/graphrag/index/errors.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""GraphRAG indexing error types.""" + + +class NoWorkflowsDefinedError(ValueError): + """Exception for no workflows defined.""" + + def __init__(self): + super().__init__("No workflows defined.") + + +class UndefinedWorkflowError(ValueError): + """Exception for invalid verb input.""" + + def __init__(self): + super().__init__("Workflow name is undefined.") + + +class UnknownWorkflowError(ValueError): + """Exception for invalid verb input.""" + + def __init__(self, name: str): + super().__init__(f"Unknown workflow: {name}") diff --git a/func-app/graphrag/index/graph/__init__.py b/func-app/graphrag/index/graph/__init__.py new file mode 100644 index 0000000000..cb26e59595 --- /dev/null +++ b/func-app/graphrag/index/graph/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph package root.""" diff --git a/func-app/graphrag/index/graph/embedding/__init__.py b/func-app/graphrag/index/graph/embedding/__init__.py new file mode 100644 index 0000000000..0ea2d085f1 --- /dev/null +++ b/func-app/graphrag/index/graph/embedding/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph embedding package root.""" + +from .embedding import NodeEmbeddings, embed_nod2vec + +__all__ = ["NodeEmbeddings", "embed_nod2vec"] diff --git a/func-app/graphrag/index/graph/embedding/embedding.py b/func-app/graphrag/index/graph/embedding/embedding.py new file mode 100644 index 0000000000..267a190f91 --- /dev/null +++ b/func-app/graphrag/index/graph/embedding/embedding.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Utilities to generate graph embeddings.""" + +from dataclasses import dataclass + +import graspologic as gc +import networkx as nx +import numpy as np + + +@dataclass +class NodeEmbeddings: + """Node embeddings class definition.""" + + nodes: list[str] + embeddings: np.ndarray + + +def embed_nod2vec( + graph: nx.Graph | nx.DiGraph, + dimensions: int = 1536, + num_walks: int = 10, + walk_length: int = 40, + window_size: int = 2, + iterations: int = 3, + random_seed: int = 86, +) -> NodeEmbeddings: + """Generate node embeddings using Node2Vec.""" + # generate embedding + lcc_tensors = gc.embed.node2vec_embed( # type: ignore + graph=graph, + dimensions=dimensions, + window_size=window_size, + iterations=iterations, + num_walks=num_walks, + walk_length=walk_length, + random_seed=random_seed, + ) + return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1]) diff --git a/func-app/graphrag/index/graph/extractors/__init__.py b/func-app/graphrag/index/graph/extractors/__init__.py new file mode 100644 index 0000000000..9168d5e207 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2024 Microsoft Corporation. 
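A usage sketch for the node2vec wrapper above on a toy graph (networkx's built-in karate-club graph is just a stand-in, and dimensions is reduced from the 1536 default for speed):

import networkx as nx

from graphrag.index.graph.embedding import embed_nod2vec

graph = nx.karate_club_graph()
result = embed_nod2vec(graph, dimensions=32, num_walks=5, walk_length=20)

# result.nodes and result.embeddings are aligned by index.
vectors = dict(zip(result.nodes, result.embeddings))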
+# Licensed under the MIT License + +"""The Indexing Engine graph extractors package root.""" + +from .claims import CLAIM_EXTRACTION_PROMPT, ClaimExtractor +from .community_reports import ( + COMMUNITY_REPORT_PROMPT, + CommunityReportsExtractor, +) +from .graph import GraphExtractionResult, GraphExtractor + +__all__ = [ + "CLAIM_EXTRACTION_PROMPT", + "COMMUNITY_REPORT_PROMPT", + "ClaimExtractor", + "CommunityReportsExtractor", + "GraphExtractionResult", + "GraphExtractor", +] diff --git a/func-app/graphrag/index/graph/extractors/claims/__init__.py b/func-app/graphrag/index/graph/extractors/claims/__init__.py new file mode 100644 index 0000000000..3977c8ff83 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/claims/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph extractors claims package root.""" + +from .claim_extractor import ClaimExtractor +from .prompts import CLAIM_EXTRACTION_PROMPT + +__all__ = ["CLAIM_EXTRACTION_PROMPT", "ClaimExtractor"] diff --git a/func-app/graphrag/index/graph/extractors/claims/claim_extractor.py b/func-app/graphrag/index/graph/extractors/claims/claim_extractor.py new file mode 100644 index 0000000000..c7e76d5067 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/claims/claim_extractor.py @@ -0,0 +1,248 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'ClaimExtractorResult' and 'ClaimExtractor' models.""" + +import logging +import traceback +from dataclasses import dataclass +from typing import Any + +import tiktoken + +import graphrag.config.defaults as defs +from graphrag.index.typing import ErrorHandlerFn +from graphrag.llm import CompletionLLM + +from .prompts import ( + CLAIM_EXTRACTION_PROMPT, + CONTINUE_PROMPT, + LOOP_PROMPT, +) + +DEFAULT_TUPLE_DELIMITER = "<|>" +DEFAULT_RECORD_DELIMITER = "##" +DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>" +log = logging.getLogger(__name__) + + +@dataclass +class ClaimExtractorResult: + """Claim extractor result class definition.""" + + output: list[dict] + source_docs: dict[str, Any] + + +class ClaimExtractor: + """Claim extractor class definition.""" + + _llm: CompletionLLM + _extraction_prompt: str + _summary_prompt: str + _output_formatter_prompt: str + _input_text_key: str + _input_entity_spec_key: str + _input_claim_description_key: str + _tuple_delimiter_key: str + _record_delimiter_key: str + _completion_delimiter_key: str + _max_gleanings: int + _on_error: ErrorHandlerFn + + def __init__( + self, + llm_invoker: CompletionLLM, + extraction_prompt: str | None = None, + input_text_key: str | None = None, + input_entity_spec_key: str | None = None, + input_claim_description_key: str | None = None, + input_resolved_entities_key: str | None = None, + tuple_delimiter_key: str | None = None, + record_delimiter_key: str | None = None, + completion_delimiter_key: str | None = None, + encoding_model: str | None = None, + max_gleanings: int | None = None, + on_error: ErrorHandlerFn | None = None, + ): + """Init method definition.""" + self._llm = llm_invoker + self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT + self._input_text_key = input_text_key or "input_text" + self._input_entity_spec_key = input_entity_spec_key or "entity_specs" + self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter" + self._record_delimiter_key = record_delimiter_key or "record_delimiter" + self._completion_delimiter_key = ( + 
completion_delimiter_key or "completion_delimiter" + ) + self._input_claim_description_key = ( + input_claim_description_key or "claim_description" + ) + self._input_resolved_entities_key = ( + input_resolved_entities_key or "resolved_entities" + ) + self._max_gleanings = ( + max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS + ) + self._on_error = on_error or (lambda _e, _s, _d: None) + + # Construct the looping arguments + encoding = tiktoken.get_encoding(encoding_model or "cl100k_base") + yes = encoding.encode("YES") + no = encoding.encode("NO") + self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1} + + async def __call__( + self, inputs: dict[str, Any], prompt_variables: dict | None = None + ) -> ClaimExtractorResult: + """Call method definition.""" + if prompt_variables is None: + prompt_variables = {} + texts = inputs[self._input_text_key] + entity_spec = str(inputs[self._input_entity_spec_key]) + claim_description = inputs[self._input_claim_description_key] + resolved_entities = inputs.get(self._input_resolved_entities_key, {}) + source_doc_map = {} + + prompt_args = { + self._input_entity_spec_key: entity_spec, + self._input_claim_description_key: claim_description, + self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key) + or DEFAULT_TUPLE_DELIMITER, + self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key) + or DEFAULT_RECORD_DELIMITER, + self._completion_delimiter_key: prompt_variables.get( + self._completion_delimiter_key + ) + or DEFAULT_COMPLETION_DELIMITER, + } + + all_claims: list[dict] = [] + for doc_index, text in enumerate(texts): + document_id = f"d{doc_index}" + try: + claims = await self._process_document(prompt_args, text, doc_index) + all_claims += [ + self._clean_claim(c, document_id, resolved_entities) for c in claims + ] + source_doc_map[document_id] = text + except Exception as e: + log.exception("error extracting claim") + self._on_error( + e, + traceback.format_exc(), + {"doc_index": doc_index, "text": text}, + ) + continue + + return ClaimExtractorResult( + output=all_claims, + source_docs=source_doc_map, + ) + + def _clean_claim( + self, claim: dict, document_id: str, resolved_entities: dict + ) -> dict: + # clean the parsed claims to remove any claims with status = False + obj = claim.get("object_id", claim.get("object")) + subject = claim.get("subject_id", claim.get("subject")) + + # If subject or object in resolved entities, then replace with resolved entity + obj = resolved_entities.get(obj, obj) + subject = resolved_entities.get(subject, subject) + claim["object_id"] = obj + claim["subject_id"] = subject + claim["doc_id"] = document_id + return claim + + async def _process_document( + self, prompt_args: dict, doc, doc_index: int + ) -> list[dict]: + record_delimiter = prompt_args.get( + self._record_delimiter_key, DEFAULT_RECORD_DELIMITER + ) + completion_delimiter = prompt_args.get( + self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER + ) + + response = await self._llm( + self._extraction_prompt, + variables={ + self._input_text_key: doc, + **prompt_args, + }, + ) + results = response.output or "" + claims = results.strip().removesuffix(completion_delimiter) + + # Repeat to ensure we maximize entity count + for i in range(self._max_gleanings): + response = await self._llm( + CONTINUE_PROMPT, + name=f"extract-continuation-{i}", + history=response.history, + ) + extension = response.output or "" + claims += record_delimiter + extension.strip().removesuffix( + 
completion_delimiter + ) + + # If this isn't the last loop, check to see if we should continue + if i >= self._max_gleanings - 1: + break + + response = await self._llm( + LOOP_PROMPT, + name=f"extract-loopcheck-{i}", + history=response.history, + model_parameters=self._loop_args, + ) + if response.output != "YES": + break + + # Parse the accumulated claims (including gleanings), not just the first response. + result = self._parse_claim_tuples(claims, prompt_args) + for r in result: + r["doc_id"] = f"{doc_index}" + return result + + def _parse_claim_tuples( + self, claims: str, prompt_variables: dict + ) -> list[dict[str, Any]]: + """Parse claim tuples.""" + record_delimiter = prompt_variables.get( + self._record_delimiter_key, DEFAULT_RECORD_DELIMITER + ) + completion_delimiter = prompt_variables.get( + self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER + ) + tuple_delimiter = prompt_variables.get( + self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER + ) + + def pull_field(index: int, fields: list[str]) -> str | None: + return fields[index].strip() if len(fields) > index else None + + result: list[dict[str, Any]] = [] + claims_values = ( + claims.strip().removesuffix(completion_delimiter).split(record_delimiter) + ) + for claim in claims_values: + claim = claim.strip().removeprefix("(").removesuffix(")") + + # Ignore the completion delimiter + if claim == completion_delimiter: + continue + + claim_fields = claim.split(tuple_delimiter) + result.append({ + "subject_id": pull_field(0, claim_fields), + "object_id": pull_field(1, claim_fields), + "type": pull_field(2, claim_fields), + "status": pull_field(3, claim_fields), + "start_date": pull_field(4, claim_fields), + "end_date": pull_field(5, claim_fields), + "description": pull_field(6, claim_fields), + "source_text": pull_field(7, claim_fields), + "doc_id": pull_field(8, claim_fields), + }) + return result diff --git a/func-app/graphrag/index/graph/extractors/claims/prompts.py b/func-app/graphrag/index/graph/extractors/claims/prompts.py new file mode 100644 index 0000000000..05b3153c20 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/claims/prompts.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A file containing prompts definition.""" + +CLAIM_EXTRACTION_PROMPT = """ +-Target activity- +You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document. + +-Goal- +Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities. + +-Steps- +1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types. +2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim. +For each claim, extract the following information: +- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1. +- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**. +- Claim Type: overall category of the claim, capitalized.
Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type +- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified. +- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references. +- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**. +- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim. + +Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>) + +3. Return output in English as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. + +4. When finished, output {completion_delimiter} + +-Examples- +Example 1: +Entity specification: organization +Claim description: red flags associated with an entity +Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015. +Output: + +(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.) +{completion_delimiter} + +Example 2: +Entity specification: Company A, Person C +Claim description: red flags associated with an entity +Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015. +Output: + +(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
+{record_delimiter} +(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015) +{completion_delimiter} + +-Real Data- +Use the following input for your answer. +Entity specification: {entity_specs} +Claim description: {claim_description} +Text: {input_text} +Output:""" + + +CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n" +LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n" diff --git a/func-app/graphrag/index/graph/extractors/community_reports/__init__.py b/func-app/graphrag/index/graph/extractors/community_reports/__init__.py new file mode 100644 index 0000000000..599f56d60f --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/community_reports/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine community reports package root.""" + +import graphrag.index.graph.extractors.community_reports.schemas as schemas + +from .build_mixed_context import build_mixed_context +from .community_reports_extractor import CommunityReportsExtractor +from .prep_community_report_context import prep_community_report_context +from .prompts import COMMUNITY_REPORT_PROMPT +from .sort_context import sort_context +from .utils import ( + filter_claims_to_nodes, + filter_edges_to_nodes, + filter_nodes_to_level, + get_levels, + set_context_exceeds_flag, + set_context_size, +) + +__all__ = [ + "COMMUNITY_REPORT_PROMPT", + "CommunityReportsExtractor", + "build_mixed_context", + "filter_claims_to_nodes", + "filter_edges_to_nodes", + "filter_nodes_to_level", + "get_levels", + "prep_community_report_context", + "schemas", + "set_context_exceeds_flag", + "set_context_size", + "sort_context", +] diff --git a/func-app/graphrag/index/graph/extractors/community_reports/build_mixed_context.py b/func-app/graphrag/index/graph/extractors/community_reports/build_mixed_context.py new file mode 100644 index 0000000000..ad9e2a8447 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/community_reports/build_mixed_context.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License +"""A module containing the build_mixed_context method definition.""" + +import pandas as pd + +import graphrag.index.graph.extractors.community_reports.schemas as schemas +from graphrag.query.llm.text_utils import num_tokens + +from .sort_context import sort_context + + +def build_mixed_context(context: list[dict], max_tokens: int) -> str: + """ + Build parent context by concatenating all sub-communities' contexts. + + If the context exceeds the limit, we use sub-community reports instead. 
+ """ + sorted_context = sorted( + context, key=lambda x: x[schemas.CONTEXT_SIZE], reverse=True + ) + + # replace local context with sub-community reports, starting from the biggest sub-community + substitute_reports = [] + final_local_contexts = [] + exceeded_limit = True + context_string = "" + + for idx, sub_community_context in enumerate(sorted_context): + if exceeded_limit: + if sub_community_context[schemas.FULL_CONTENT]: + substitute_reports.append({ + schemas.COMMUNITY_ID: sub_community_context[schemas.SUB_COMMUNITY], + schemas.FULL_CONTENT: sub_community_context[schemas.FULL_CONTENT], + }) + else: + # this sub-community has no report, so we will use its local context + final_local_contexts.extend(sub_community_context[schemas.ALL_CONTEXT]) + continue + + # add local context for the remaining sub-communities + remaining_local_context = [] + for rid in range(idx + 1, len(sorted_context)): + remaining_local_context.extend(sorted_context[rid][schemas.ALL_CONTEXT]) + new_context_string = sort_context( + local_context=remaining_local_context + final_local_contexts, + sub_community_reports=substitute_reports, + ) + if num_tokens(new_context_string) <= max_tokens: + exceeded_limit = False + context_string = new_context_string + break + + if exceeded_limit: + # if all sub-community reports exceed the limit, we add reports until context is full + substitute_reports = [] + for sub_community_context in sorted_context: + substitute_reports.append({ + schemas.COMMUNITY_ID: sub_community_context[schemas.SUB_COMMUNITY], + schemas.FULL_CONTENT: sub_community_context[schemas.FULL_CONTENT], + }) + new_context_string = pd.DataFrame(substitute_reports).to_csv( + index=False, sep="," + ) + if num_tokens(new_context_string) > max_tokens: + break + + context_string = new_context_string + return context_string diff --git a/func-app/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py b/func-app/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py new file mode 100644 index 0000000000..309336fee7 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py @@ -0,0 +1,107 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing 'CommunityReportsResult' and 'CommunityReportsExtractor' models.""" + +import logging +import traceback +from dataclasses import dataclass +from typing import Any + +from graphrag.index.typing import ErrorHandlerFn +from graphrag.index.utils import dict_has_keys_with_types +from graphrag.llm import CompletionLLM + +from .prompts import COMMUNITY_REPORT_PROMPT + +log = logging.getLogger(__name__) + + +@dataclass +class CommunityReportsResult: + """Community reports result class definition.""" + + output: str + structured_output: dict + + +class CommunityReportsExtractor: + """Community reports extractor class definition.""" + + _llm: CompletionLLM + _input_text_key: str + _extraction_prompt: str + _output_formatter_prompt: str + _on_error: ErrorHandlerFn + _max_report_length: int + + def __init__( + self, + llm_invoker: CompletionLLM, + input_text_key: str | None = None, + extraction_prompt: str | None = None, + on_error: ErrorHandlerFn | None = None, + max_report_length: int | None = None, + ): + """Init method definition.""" + self._llm = llm_invoker + self._input_text_key = input_text_key or "input_text" + self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT + self._on_error = on_error or (lambda _e, _s, _d: None) + self._max_report_length = max_report_length or 1500 + + async def __call__(self, inputs: dict[str, Any]): + """Call method definition.""" + output = None + try: + response = ( + await self._llm( + self._extraction_prompt, + json=True, + name="create_community_report", + variables={self._input_text_key: inputs[self._input_text_key]}, + is_response_valid=lambda x: dict_has_keys_with_types( + x, + [ + ("title", str), + ("summary", str), + ("findings", list), + ("rating", float), + ("rating_explanation", str), + ], + ), + model_parameters={"max_tokens": self._max_report_length}, + ) + or {} + ) + output = response.json or {} + except Exception as e: + log.exception("error generating community report") + self._on_error(e, traceback.format_exc(), None) + output = {} + + text_output = self._get_text_output(output) + return CommunityReportsResult( + structured_output=output, + output=text_output, + ) + + def _get_text_output(self, parsed_output: dict) -> str: + title = parsed_output.get("title", "Report") + summary = parsed_output.get("summary", "") + findings = parsed_output.get("findings", []) + + def finding_summary(finding: dict): + if isinstance(finding, str): + return finding + return finding.get("summary") + + def finding_explanation(finding: dict): + if isinstance(finding, str): + return "" + return finding.get("explanation") + + report_sections = "\n\n".join( + f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings + ) + return f"# {title}\n\n{summary}\n\n{report_sections}" diff --git a/func-app/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py b/func-app/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py new file mode 100644 index 0000000000..2ec4222024 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py @@ -0,0 +1,181 @@ +# Copyright (c) 2024 Microsoft Corporation. 
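The is_response_valid gate above rejects any LLM response whose parsed JSON is missing a field or carries the wrong type. A rough stand-in for dict_has_keys_with_types (hypothetical re-implementation, not the shipped helper):

def has_keys_with_types(data: dict, expected: list[tuple[str, type]]) -> bool:
    """Hypothetical stand-in for graphrag.index.utils.dict_has_keys_with_types."""
    return all(key in data and isinstance(data[key], t) for key, t in expected)


report = {
    "title": "Verdant Oasis Plaza and Unity March",
    "summary": "...",
    "findings": [],
    "rating": 5.0,
    "rating_explanation": "...",
}
assert has_keys_with_types(report, [
    ("title", str),
    ("summary", str),
    ("findings", list),
    ("rating", float),
    ("rating_explanation", str),
])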
+# Licensed under the MIT License + +"""A module containing create_community_reports and load_strategy methods definition.""" + +import logging +from typing import cast + +import pandas as pd + +import graphrag.index.graph.extractors.community_reports.schemas as schemas +from graphrag.index.utils.dataframes import ( + antijoin, + drop_columns, + join, + select, + transform_series, + union, + where_column_equals, +) + +from .build_mixed_context import build_mixed_context +from .sort_context import sort_context +from .utils import set_context_size + +log = logging.getLogger(__name__) + + +def prep_community_report_context( + report_df: pd.DataFrame | None, + community_hierarchy_df: pd.DataFrame, + local_context_df: pd.DataFrame, + level: int | str, + max_tokens: int, +) -> pd.DataFrame: + """ + Prep context for each community in a given level. + + For each community: + - Check if local context fits within the limit, if yes use local context + - If local context exceeds the limit, iteratively replace local context with sub-community reports, starting from the biggest sub-community + """ + if report_df is None: + report_df = pd.DataFrame() + + level = int(level) + level_context_df = _at_level(level, local_context_df) + valid_context_df = _within_context(level_context_df) + invalid_context_df = _exceeding_context(level_context_df) + + # there is no report to substitute with, so we just trim the local context of the invalid context records + # this case should only happen at the bottom level of the community hierarchy where there are no sub-communities + if invalid_context_df.empty: + return valid_context_df + + if report_df.empty: + invalid_context_df[schemas.CONTEXT_STRING] = _sort_and_trim_context( + invalid_context_df, max_tokens + ) + set_context_size(invalid_context_df) + invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = 0 + return union(valid_context_df, invalid_context_df) + + level_context_df = _antijoin_reports(level_context_df, report_df) + + # for each invalid context, we will try to substitute with sub-community reports + # first get local context and report (if available) for each sub-community + sub_context_df = _get_subcontext_df(level + 1, report_df, local_context_df) + community_df = _get_community_df( + level, invalid_context_df, sub_context_df, community_hierarchy_df, max_tokens + ) + + # handle any remaining invalid records that can't be substituted with sub-community reports + # this should be rare, but if it happens, we will just trim the local context to fit the limit + remaining_df = _antijoin_reports(invalid_context_df, community_df) + remaining_df[schemas.CONTEXT_STRING] = _sort_and_trim_context( + remaining_df, max_tokens + ) + + result = union(valid_context_df, community_df, remaining_df) + set_context_size(result) + result[schemas.CONTEXT_EXCEED_FLAG] = 0 + return result + + +def _drop_community_level(df: pd.DataFrame) -> pd.DataFrame: + """Drop the community level column from the dataframe.""" + return drop_columns(df, schemas.COMMUNITY_LEVEL) + + +def _at_level(level: int, df: pd.DataFrame) -> pd.DataFrame: + """Return records at the given level.""" + return where_column_equals(df, schemas.COMMUNITY_LEVEL, level) + + +def _exceeding_context(df: pd.DataFrame) -> pd.DataFrame: + """Return records where the context exceeds the limit.""" + return where_column_equals(df, schemas.CONTEXT_EXCEED_FLAG, 1) + + +def _within_context(df: pd.DataFrame) -> pd.DataFrame: + """Return records where the context is within the limit.""" + return where_column_equals(df,
schemas.CONTEXT_EXCEED_FLAG, 0) + + +def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame: + """Return records in df that are not in reports.""" + return antijoin(df, reports, schemas.NODE_COMMUNITY) + + +def _sort_and_trim_context(df: pd.DataFrame, max_tokens: int) -> pd.Series: + """Sort and trim context to fit the limit.""" + series = cast(pd.Series, df[schemas.ALL_CONTEXT]) + return transform_series(series, lambda x: sort_context(x, max_tokens=max_tokens)) + + +def _build_mixed_context(df: pd.DataFrame, max_tokens: int) -> pd.Series: + """Sort and trim context to fit the limit.""" + series = cast(pd.Series, df[schemas.ALL_CONTEXT]) + return transform_series( + series, lambda x: build_mixed_context(x, max_tokens=max_tokens) + ) + + +def _get_subcontext_df( + level: int, report_df: pd.DataFrame, local_context_df: pd.DataFrame +) -> pd.DataFrame: + """Get sub-community context for each community.""" + sub_report_df = _drop_community_level(_at_level(level, report_df)) + sub_context_df = _at_level(level, local_context_df) + sub_context_df = join(sub_context_df, sub_report_df, schemas.NODE_COMMUNITY) + sub_context_df.rename( + columns={schemas.NODE_COMMUNITY: schemas.SUB_COMMUNITY}, inplace=True + ) + return sub_context_df + + +def _get_community_df( + level: int, + invalid_context_df: pd.DataFrame, + sub_context_df: pd.DataFrame, + community_hierarchy_df: pd.DataFrame, + max_tokens: int, +) -> pd.DataFrame: + """Get community context for each community.""" + # collect all sub communities' contexts for each community + community_df = _drop_community_level(_at_level(level, community_hierarchy_df)) + invalid_community_ids = select(invalid_context_df, schemas.NODE_COMMUNITY) + subcontext_selection = select( + sub_context_df, + schemas.SUB_COMMUNITY, + schemas.FULL_CONTENT, + schemas.ALL_CONTEXT, + schemas.CONTEXT_SIZE, + ) + + invalid_communities = join( + community_df, invalid_community_ids, schemas.NODE_COMMUNITY, "inner" + ) + community_df = join( + invalid_communities, subcontext_selection, schemas.SUB_COMMUNITY + ) + community_df[schemas.ALL_CONTEXT] = community_df.apply( + lambda x: { + schemas.SUB_COMMUNITY: x[schemas.SUB_COMMUNITY], + schemas.ALL_CONTEXT: x[schemas.ALL_CONTEXT], + schemas.FULL_CONTENT: x[schemas.FULL_CONTENT], + schemas.CONTEXT_SIZE: x[schemas.CONTEXT_SIZE], + }, + axis=1, + ) + community_df = ( + community_df.groupby(schemas.NODE_COMMUNITY) + .agg({schemas.ALL_CONTEXT: list}) + .reset_index() + ) + community_df[schemas.CONTEXT_STRING] = _build_mixed_context( + community_df, max_tokens + ) + community_df[schemas.COMMUNITY_LEVEL] = level + return community_df diff --git a/func-app/graphrag/index/graph/extractors/community_reports/prompts.py b/func-app/graphrag/index/graph/extractors/community_reports/prompts.py new file mode 100644 index 0000000000..35ca38bc8b --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/community_reports/prompts.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License +"""A file containing prompts definition.""" + +COMMUNITY_REPORT_PROMPT = """ +You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network. 
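The context-prep flow above hinges on a small dataframe algebra; the antijoin that drops communities which already have reports reduces to plain pandas. A rough sketch (the shipped helper in graphrag.index.utils.dataframes may differ in implementation):

import pandas as pd


def antijoin_sketch(df: pd.DataFrame, exclude: pd.DataFrame, column: str) -> pd.DataFrame:
    """Rough stand-in for antijoin: keep rows of df whose key is absent from exclude."""
    return df[~df[column].isin(exclude[column])]


contexts = pd.DataFrame({"community": [1, 2, 3]})
reports = pd.DataFrame({"community": [2]})
print(antijoin_sketch(contexts, reports, "community"))  # rows for communities 1 and 3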
+ +# Goal +Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims. + +# Report Structure + +The report should include the following sections: + +- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. +- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. +- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community. +- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating. +- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. + +Return output as a well-formed JSON-formatted string with the following format: + {{ + "title": <report_title>, + "summary": <executive_summary>, + "rating": <impact_severity_rating>, + "rating_explanation": <rating_explanation>, + "findings": [ + {{ + "summary":<insight_1_summary>, + "explanation": <insight_1_explanation> + }}, + {{ + "summary":<insight_2_summary>, + "explanation": <insight_2_explanation> + }} + ] + }} + +# Grounding Rules + +Points supported by data should list their data references as follows: + +"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]." + +Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]." + +where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +# Example Input +----------- +Text: + +Entities + +id,entity,description +5,VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March +6,HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza + +Relationships + +id,source,target,description +37,VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March +38,VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza +39,VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza +40,VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza +41,VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march +43,HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March + +Output: +{{ + "title": "Verdant Oasis Plaza and Unity March", + "summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March.
The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.", + "rating": 5.0, + "rating_explanation": "The impact severity rating is moderate due to the potential for unrest or conflict during the Unity March.", + "findings": [ + {{ + "summary": "Verdant Oasis Plaza as the central location", + "explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes. [Data: Entities (5), Relationships (37, 38, 39, 40, 41,+more)]" + }}, + {{ + "summary": "Harmony Assembly's role in the community", + "explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community. [Data: Entities(6), Relationships (38, 43)]" + }}, + {{ + "summary": "Unity March as a significant event", + "explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community. [Data: Relationships (39)]" + }}, + {{ + "summary": "Role of Tribune Spotlight", + "explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved. [Data: Relationships (40)]" + }} + ] +}} + + +# Real Data + +Use the following text for your answer. Do not make anything up in your answer. + +Text: +{input_text} + +The report should include the following sections: + +- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. +- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. +- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community. +- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating. +- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. 
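The grounding format this prompt enforces is mechanical enough to render programmatically; a hypothetical helper for the [Data: ...] citations, capping each section at five ids as the rules require:

def data_reference(**sections: list[int]) -> str:
    """Hypothetical renderer for the grounding-rule citation format."""
    parts = []
    for name, ids in sections.items():
        shown = ", ".join(str(i) for i in ids[:5])
        more = ", +more" if len(ids) > 5 else ""
        parts.append(f"{name.title()} ({shown}{more})")
    return "[Data: " + "; ".join(parts) + "]"


print(data_reference(entities=[5, 7], relationships=[23]))
# [Data: Entities (5, 7); Relationships (23)]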
+
+Return output as a well-formed JSON-formatted string with the following format:
+    {{
+        "title": <report_title>,
+        "summary": <executive_summary>,
+        "rating": <impact_severity_rating>,
+        "rating_explanation": <rating_explanation>,
+        "findings": [
+            {{
+                "summary": <insight_1_summary>,
+                "explanation": <insight_1_explanation>
+            }},
+            {{
+                "summary": <insight_2_summary>,
+                "explanation": <insight_2_explanation>
+            }}
+        ]
+    }}
+
+# Grounding Rules
+
+Points supported by data should list their data references as follows:
+
+"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
+
+Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
+
+where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+Do not include information where the supporting evidence for it is not provided.
+
+Output:"""
diff --git a/func-app/graphrag/index/graph/extractors/community_reports/schemas.py b/func-app/graphrag/index/graph/extractors/community_reports/schemas.py
new file mode 100644
index 0000000000..8e89e0273c
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/community_reports/schemas.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+"""Common field name definitions for community reports."""
+
+# POST-PREP NODE TABLE SCHEMA
+NODE_ID = "human_readable_id"
+NODE_NAME = "title"
+NODE_DESCRIPTION = "description"
+NODE_DEGREE = "degree"
+NODE_DETAILS = "node_details"
+NODE_COMMUNITY = "community"
+NODE_LEVEL = "level"
+
+# POST-PREP EDGE TABLE SCHEMA
+EDGE_ID = "human_readable_id"
+EDGE_SOURCE = "source"
+EDGE_TARGET = "target"
+EDGE_DESCRIPTION = "description"
+EDGE_DEGREE = "rank"
+EDGE_DETAILS = "edge_details"
+EDGE_WEIGHT = "weight"
+
+# POST-PREP CLAIM TABLE SCHEMA
+CLAIM_ID = "human_readable_id"
+CLAIM_SUBJECT = "subject_id"
+CLAIM_TYPE = "type"
+CLAIM_STATUS = "status"
+CLAIM_DESCRIPTION = "description"
+CLAIM_DETAILS = "claim_details"
+
+# COMMUNITY HIERARCHY TABLE SCHEMA
+SUB_COMMUNITY = "sub_community"
+SUB_COMMUNITY_SIZE = "sub_community_size"
+COMMUNITY_LEVEL = "level"
+
+# COMMUNITY CONTEXT TABLE SCHEMA
+ALL_CONTEXT = "all_context"
+CONTEXT_STRING = "context_string"
+CONTEXT_SIZE = "context_size"
+CONTEXT_EXCEED_FLAG = "context_exceed_limit"
+
+# COMMUNITY REPORT TABLE SCHEMA
+REPORT_ID = "id"
+COMMUNITY_ID = "id"
+COMMUNITY_LEVEL = "level"
+TITLE = "title"
+SUMMARY = "summary"
+FINDINGS = "findings"
+RATING = "rank"
+EXPLANATION = "rating_explanation"
+FULL_CONTENT = "full_content"
+FULL_CONTENT_JSON = "full_content_json"
diff --git a/func-app/graphrag/index/graph/extractors/community_reports/sort_context.py b/func-app/graphrag/index/graph/extractors/community_reports/sort_context.py
new file mode 100644
index 0000000000..811cb7e95c
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/community_reports/sort_context.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License +"""Sort context by degree in descending order.""" + +import pandas as pd + +import graphrag.index.graph.extractors.community_reports.schemas as schemas +from graphrag.query.llm.text_utils import num_tokens + + +def sort_context( + local_context: list[dict], + sub_community_reports: list[dict] | None = None, + max_tokens: int | None = None, + node_id_column: str = schemas.NODE_ID, + node_name_column: str = schemas.NODE_NAME, + node_details_column: str = schemas.NODE_DETAILS, + edge_id_column: str = schemas.EDGE_ID, + edge_details_column: str = schemas.EDGE_DETAILS, + edge_degree_column: str = schemas.EDGE_DEGREE, + edge_source_column: str = schemas.EDGE_SOURCE, + edge_target_column: str = schemas.EDGE_TARGET, + claim_id_column: str = schemas.CLAIM_ID, + claim_details_column: str = schemas.CLAIM_DETAILS, + community_id_column: str = schemas.COMMUNITY_ID, +) -> str: + """Sort context by degree in descending order. + + If max tokens is provided, we will return the context string that fits within the token limit. + """ + + def _get_context_string( + entities: list[dict], + edges: list[dict], + claims: list[dict], + sub_community_reports: list[dict] | None = None, + ) -> str: + """Concatenate structured data into a context string.""" + contexts = [] + if sub_community_reports: + sub_community_reports = [ + report + for report in sub_community_reports + if community_id_column in report + and report[community_id_column] + and str(report[community_id_column]).strip() != "" + ] + report_df = pd.DataFrame(sub_community_reports).drop_duplicates() + if not report_df.empty: + if report_df[community_id_column].dtype == float: + report_df[community_id_column] = report_df[ + community_id_column + ].astype(int) + report_string = ( + f"----Reports-----\n{report_df.to_csv(index=False, sep=',')}" + ) + contexts.append(report_string) + + entities = [ + entity + for entity in entities + if node_id_column in entity + and entity[node_id_column] + and str(entity[node_id_column]).strip() != "" + ] + entity_df = pd.DataFrame(entities).drop_duplicates() + if not entity_df.empty: + if entity_df[node_id_column].dtype == float: + entity_df[node_id_column] = entity_df[node_id_column].astype(int) + entity_string = ( + f"-----Entities-----\n{entity_df.to_csv(index=False, sep=',')}" + ) + contexts.append(entity_string) + + if claims and len(claims) > 0: + claims = [ + claim + for claim in claims + if claim_id_column in claim + and claim[claim_id_column] + and str(claim[claim_id_column]).strip() != "" + ] + claim_df = pd.DataFrame(claims).drop_duplicates() + if not claim_df.empty: + if claim_df[claim_id_column].dtype == float: + claim_df[claim_id_column] = claim_df[claim_id_column].astype(int) + claim_string = ( + f"-----Claims-----\n{claim_df.to_csv(index=False, sep=',')}" + ) + contexts.append(claim_string) + + edges = [ + edge + for edge in edges + if edge_id_column in edge + and edge[edge_id_column] + and str(edge[edge_id_column]).strip() != "" + ] + edge_df = pd.DataFrame(edges).drop_duplicates() + if not edge_df.empty: + if edge_df[edge_id_column].dtype == float: + edge_df[edge_id_column] = edge_df[edge_id_column].astype(int) + edge_string = ( + f"-----Relationships-----\n{edge_df.to_csv(index=False, sep=',')}" + ) + contexts.append(edge_string) + + return "\n\n".join(contexts) + + # sort node details by degree in descending order + edges = [] + node_details = {} + claim_details = {} + + for record in local_context: + node_name = record[node_name_column] + record_edges = 
record.get(edge_details_column, [])
+        record_edges = [e for e in record_edges if not pd.isna(e)]
+        record_node_details = record[node_details_column]
+        record_claims = record.get(claim_details_column, [])
+        record_claims = [c for c in record_claims if not pd.isna(c)]
+
+        edges.extend(record_edges)
+        node_details[node_name] = record_node_details
+        claim_details[node_name] = record_claims
+
+    edges = [edge for edge in edges if isinstance(edge, dict)]
+    edges = sorted(edges, key=lambda x: x[edge_degree_column], reverse=True)
+
+    sorted_edges = []
+    sorted_nodes = []
+    sorted_claims = []
+    context_string = ""
+    for edge in edges:
+        source_details = node_details.get(edge[edge_source_column], {})
+        target_details = node_details.get(edge[edge_target_column], {})
+        sorted_nodes.extend([source_details, target_details])
+        sorted_edges.append(edge)
+        source_claims = claim_details.get(edge[edge_source_column], [])
+        target_claims = claim_details.get(edge[edge_target_column], [])
+        sorted_claims.extend(source_claims if source_claims else [])
+        sorted_claims.extend(target_claims if target_claims else [])
+        if max_tokens:
+            new_context_string = _get_context_string(
+                sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
+            )
+            # stop before the accumulated context exceeds the token budget
+            if num_tokens(new_context_string) > max_tokens:
+                break
+            context_string = new_context_string
+
+    if context_string == "":
+        return _get_context_string(
+            sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
+        )
+
+    return context_string
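+
+
+# A minimal usage sketch (hypothetical records, using the schema column names
+# from `schemas`); `max_tokens` caps the context at the token budget:
+#
+#     local_context = [
+#         {
+#             "title": "VERDANT OASIS PLAZA",
+#             "node_details": {"human_readable_id": 5, "degree": 4},
+#             "edge_details": [
+#                 {
+#                     "human_readable_id": 37,
+#                     "source": "VERDANT OASIS PLAZA",
+#                     "target": "UNITY MARCH",
+#                     "rank": 9,
+#                 }
+#             ],
+#         },
+#     ]
+#     context_string = sort_context(local_context, max_tokens=16_000)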
diff --git a/func-app/graphrag/index/graph/extractors/community_reports/utils.py b/func-app/graphrag/index/graph/extractors/community_reports/utils.py
new file mode 100644
index 0000000000..b5fc9af9b8
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/community_reports/utils.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing community report generation utilities."""
+
+from typing import cast
+
+import pandas as pd
+
+import graphrag.index.graph.extractors.community_reports.schemas as schemas
+from graphrag.query.llm.text_utils import num_tokens
+
+
+def set_context_size(df: pd.DataFrame) -> None:
+    """Measure the number of tokens in the context."""
+    df[schemas.CONTEXT_SIZE] = df[schemas.CONTEXT_STRING].apply(lambda x: num_tokens(x))
+
+
+def set_context_exceeds_flag(df: pd.DataFrame, max_tokens: int) -> None:
+    """Set a flag to indicate if the context exceeds the limit."""
+    df[schemas.CONTEXT_EXCEED_FLAG] = df[schemas.CONTEXT_SIZE].apply(
+        lambda x: x > max_tokens
+    )
+
+
+def get_levels(df: pd.DataFrame, level_column: str = schemas.NODE_LEVEL) -> list[int]:
+    """Get the levels of the communities."""
+    result = sorted(df[level_column].fillna(-1).unique().tolist(), reverse=True)
+    return [r for r in result if r != -1]
+
+
+def filter_nodes_to_level(node_df: pd.DataFrame, level: int) -> pd.DataFrame:
+    """Filter nodes to level."""
+    return cast(pd.DataFrame, node_df[node_df[schemas.NODE_LEVEL] == level])
+
+
+def filter_edges_to_nodes(edge_df: pd.DataFrame, nodes: list[str]) -> pd.DataFrame:
+    """Filter edges to nodes."""
+    return cast(
+        pd.DataFrame,
+        edge_df[
+            edge_df[schemas.EDGE_SOURCE].isin(nodes)
+            & edge_df[schemas.EDGE_TARGET].isin(nodes)
+        ],
+    )
+
+
+def filter_claims_to_nodes(claims_df: pd.DataFrame, nodes: list[str]) -> pd.DataFrame:
+    """Filter claims to nodes."""
+    return cast(
+        pd.DataFrame,
+        claims_df[claims_df[schemas.CLAIM_SUBJECT].isin(nodes)],
+    )
diff --git a/func-app/graphrag/index/graph/extractors/graph/__init__.py b/func-app/graphrag/index/graph/extractors/graph/__init__.py
new file mode 100644
index 0000000000..94e03ab9f7
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/graph/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine unipartite graph package root."""
+
+from .graph_extractor import (
+    DEFAULT_ENTITY_TYPES,
+    GraphExtractionResult,
+    GraphExtractor,
+)
+from .prompts import GRAPH_EXTRACTION_PROMPT
+
+__all__ = [
+    "DEFAULT_ENTITY_TYPES",
+    "GRAPH_EXTRACTION_PROMPT",
+    "GraphExtractionResult",
+    "GraphExtractor",
+]
diff --git a/func-app/graphrag/index/graph/extractors/graph/graph_extractor.py b/func-app/graphrag/index/graph/extractors/graph/graph_extractor.py
new file mode 100644
index 0000000000..f1ba0011f9
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/graph/graph_extractor.py
@@ -0,0 +1,305 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""A module containing 'GraphExtractionResult' and 'GraphExtractor' models.""" + +import logging +import numbers +import re +import traceback +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import networkx as nx +import tiktoken + +import graphrag.config.defaults as defs +from graphrag.index.typing import ErrorHandlerFn +from graphrag.index.utils import clean_str +from graphrag.llm import CompletionLLM + +from .prompts import CONTINUE_PROMPT, GRAPH_EXTRACTION_PROMPT, LOOP_PROMPT + +DEFAULT_TUPLE_DELIMITER = "<|>" +DEFAULT_RECORD_DELIMITER = "##" +DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>" +DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"] + + +@dataclass +class GraphExtractionResult: + """Unipartite graph extraction result class definition.""" + + output: nx.Graph + source_docs: dict[Any, Any] + + +class GraphExtractor: + """Unipartite graph extractor class definition.""" + + _llm: CompletionLLM + _join_descriptions: bool + _tuple_delimiter_key: str + _record_delimiter_key: str + _entity_types_key: str + _input_text_key: str + _completion_delimiter_key: str + _entity_name_key: str + _input_descriptions_key: str + _extraction_prompt: str + _summarization_prompt: str + _loop_args: dict[str, Any] + _max_gleanings: int + _on_error: ErrorHandlerFn + + def __init__( + self, + llm_invoker: CompletionLLM, + tuple_delimiter_key: str | None = None, + record_delimiter_key: str | None = None, + input_text_key: str | None = None, + entity_types_key: str | None = None, + completion_delimiter_key: str | None = None, + prompt: str | None = None, + join_descriptions=True, + encoding_model: str | None = None, + max_gleanings: int | None = None, + on_error: ErrorHandlerFn | None = None, + ): + """Init method definition.""" + # TODO: streamline construction + self._llm = llm_invoker + self._join_descriptions = join_descriptions + self._input_text_key = input_text_key or "input_text" + self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter" + self._record_delimiter_key = record_delimiter_key or "record_delimiter" + self._completion_delimiter_key = ( + completion_delimiter_key or "completion_delimiter" + ) + self._entity_types_key = entity_types_key or "entity_types" + self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT + self._max_gleanings = ( + max_gleanings + if max_gleanings is not None + else defs.ENTITY_EXTRACTION_MAX_GLEANINGS + ) + self._on_error = on_error or (lambda _e, _s, _d: None) + + # Construct the looping arguments + encoding = tiktoken.get_encoding(encoding_model or "cl100k_base") + yes = encoding.encode("YES") + no = encoding.encode("NO") + self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1} + + async def __call__( + self, texts: list[str], prompt_variables: dict[str, Any] | None = None + ) -> GraphExtractionResult: + """Call method definition.""" + if prompt_variables is None: + prompt_variables = {} + all_records: dict[int, str] = {} + source_doc_map: dict[int, str] = {} + + # Wire defaults into the prompt variables + prompt_variables = { + **prompt_variables, + self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key) + or DEFAULT_TUPLE_DELIMITER, + self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key) + or DEFAULT_RECORD_DELIMITER, + self._completion_delimiter_key: prompt_variables.get( + self._completion_delimiter_key + ) + or DEFAULT_COMPLETION_DELIMITER, + self._entity_types_key: ",".join( + 
prompt_variables[self._entity_types_key] or DEFAULT_ENTITY_TYPES + ), + } + + for doc_index, text in enumerate(texts): + try: + # Invoke the entity extraction + result = await self._process_document(text, prompt_variables) + source_doc_map[doc_index] = text + all_records[doc_index] = result + except Exception as e: + logging.exception("error extracting graph") + self._on_error( + e, + traceback.format_exc(), + { + "doc_index": doc_index, + "text": text, + }, + ) + + output = await self._process_results( + all_records, + prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER), + prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER), + ) + + return GraphExtractionResult( + output=output, + source_docs=source_doc_map, + ) + + async def _process_document( + self, text: str, prompt_variables: dict[str, str] + ) -> str: + response = await self._llm( + self._extraction_prompt, + variables={ + **prompt_variables, + self._input_text_key: text, + }, + ) + results = response.output or "" + + # Repeat to ensure we maximize entity count + for i in range(self._max_gleanings): + response = await self._llm( + CONTINUE_PROMPT, + name=f"extract-continuation-{i}", + history=response.history, + ) + results += response.output or "" + + # if this is the final glean, don't bother updating the continuation flag + if i >= self._max_gleanings - 1: + break + + response = await self._llm( + LOOP_PROMPT, + name=f"extract-loopcheck-{i}", + history=response.history, + model_parameters=self._loop_args, + ) + if response.output != "YES": + break + + return results + + async def _process_results( + self, + results: dict[int, str], + tuple_delimiter: str, + record_delimiter: str, + ) -> nx.Graph: + """Parse the result string to create an undirected unipartite graph. 
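+
+        Each record is expected to look like
+        ("entity"<|>NAME<|>TYPE<|>DESCRIPTION) or
+        ("relationship"<|>SOURCE<|>TARGET<|>DESCRIPTION<|>WEIGHT)
+        when the default delimiters are in use.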
+
+        Args:
+            - results - dict of results from the extraction chain
+            - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
+            - record_delimiter - delimiter between records, default is '##'
+        Returns:
+            - output - unipartite graph as a networkx.Graph
+        """
+        graph = nx.Graph()
+        for source_doc_id, extracted_data in results.items():
+            records = [r.strip() for r in extracted_data.split(record_delimiter)]
+
+            for record in records:
+                record = re.sub(r"^\(|\)$", "", record.strip())
+                record_attributes = record.split(tuple_delimiter)
+
+                if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
+                    # add this record as a node in the G
+                    entity_name = clean_str(record_attributes[1].upper())
+                    entity_type = clean_str(record_attributes[2].upper())
+                    entity_description = clean_str(record_attributes[3])
+
+                    if entity_name in graph.nodes():
+                        node = graph.nodes[entity_name]
+                        if self._join_descriptions:
+                            node["description"] = "\n".join(
+                                list({
+                                    *_unpack_descriptions(node),
+                                    entity_description,
+                                })
+                            )
+                        else:
+                            if len(entity_description) > len(node["description"]):
+                                node["description"] = entity_description
+                        node["source_id"] = ", ".join(
+                            list({
+                                *_unpack_source_ids(node),
+                                str(source_doc_id),
+                            })
+                        )
+                        node["entity_type"] = (
+                            entity_type if entity_type != "" else node["entity_type"]
+                        )
+                    else:
+                        graph.add_node(
+                            entity_name,
+                            type=entity_type,
+                            description=entity_description,
+                            source_id=str(source_doc_id),
+                        )
+
+                if (
+                    record_attributes[0] == '"relationship"'
+                    and len(record_attributes) >= 5
+                ):
+                    # add this record as edge
+                    source = clean_str(record_attributes[1].upper())
+                    target = clean_str(record_attributes[2].upper())
+                    edge_description = clean_str(record_attributes[3])
+                    edge_source_id = clean_str(str(source_doc_id))
+                    # the relationship strength arrives as a string; parse it,
+                    # defaulting to 1.0 when it is not numeric
+                    try:
+                        weight = float(record_attributes[-1])
+                    except ValueError:
+                        weight = 1.0
+                    if source not in graph.nodes():
+                        graph.add_node(
+                            source,
+                            type="",
+                            description="",
+                            source_id=edge_source_id,
+                        )
+                    if target not in graph.nodes():
+                        graph.add_node(
+                            target,
+                            type="",
+                            description="",
+                            source_id=edge_source_id,
+                        )
+                    if graph.has_edge(source, target):
+                        edge_data = graph.get_edge_data(source, target)
+                        if edge_data is not None:
+                            weight += edge_data["weight"]
+                            if self._join_descriptions:
+                                edge_description = "\n".join(
+                                    list({
+                                        *_unpack_descriptions(edge_data),
+                                        edge_description,
+                                    })
+                                )
+                            edge_source_id = ", ".join(
+                                list({
+                                    *_unpack_source_ids(edge_data),
+                                    str(source_doc_id),
+                                })
+                            )
+                    graph.add_edge(
+                        source,
+                        target,
+                        weight=weight,
+                        description=edge_description,
+                        source_id=edge_source_id,
+                    )
+
+        return graph
+
+
+def _unpack_descriptions(data: Mapping) -> list[str]:
+    value = data.get("description", None)
+    return [] if value is None else value.split("\n")
+
+
+def _unpack_source_ids(data: Mapping) -> list[str]:
+    value = data.get("source_id", None)
+    return [] if value is None else value.split(", ")
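+
+
+# A minimal wiring sketch (hypothetical `llm`; any CompletionLLM will do):
+#
+#     extractor = GraphExtractor(llm_invoker=llm)
+#     result = await extractor(
+#         ["Martin Smith chairs the Central Institution."],
+#         {"entity_types": ["organization", "person"]},
+#     )
+#     graph = result.output  # networkx.Graph of typed, described entities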
diff --git a/func-app/graphrag/index/graph/extractors/graph/prompts.py b/func-app/graphrag/index/graph/extractors/graph/prompts.py
new file mode 100644
index 0000000000..cb1bcc668a
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/graph/prompts.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A file containing prompts definition."""
+
+GRAPH_EXTRACTION_PROMPT = """
+-Goal-
+Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
+
+-Steps-
+1. Identify all entities. For each identified entity, extract the following information:
+- entity_name: Name of the entity, capitalized
+- entity_type: One of the following types: [{entity_types}]
+- entity_description: Comprehensive description of the entity's attributes and activities
+Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
+
+2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: name of the source entity, as identified in step 1
+- target_entity: name of the target entity, as identified in step 1
+- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
+- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
+Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
+
+3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
+
+4. When finished, output {completion_delimiter}
+
+######################
+-Examples-
+######################
+Example 1:
+Entity_types: ORGANIZATION,PERSON
+Text:
+The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%.
+######################
+Output:
+("entity"{tuple_delimiter}CENTRAL INSTITUTION{tuple_delimiter}ORGANIZATION{tuple_delimiter}The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday)
+{record_delimiter}
+("entity"{tuple_delimiter}MARTIN SMITH{tuple_delimiter}PERSON{tuple_delimiter}Martin Smith is the chair of the Central Institution)
+{record_delimiter}
+("entity"{tuple_delimiter}MARKET STRATEGY COMMITTEE{tuple_delimiter}ORGANIZATION{tuple_delimiter}The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply)
+{record_delimiter}
+("relationship"{tuple_delimiter}MARTIN SMITH{tuple_delimiter}CENTRAL INSTITUTION{tuple_delimiter}Martin Smith is the Chair of the Central Institution and will answer questions at a press conference{tuple_delimiter}9)
+{completion_delimiter}
+
+######################
+Example 2:
+Entity_types: ORGANIZATION
+Text:
+TechGlobal's (TG) stock skyrocketed in its opening day on the Global Exchange Thursday. But IPO experts warn that the semiconductor corporation's debut on the public markets isn't indicative of how other newly listed companies may perform.
+
+TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones.
+######################
+Output:
+("entity"{tuple_delimiter}TECHGLOBAL{tuple_delimiter}ORGANIZATION{tuple_delimiter}TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones)
+{record_delimiter}
+("entity"{tuple_delimiter}VISION HOLDINGS{tuple_delimiter}ORGANIZATION{tuple_delimiter}Vision Holdings is a firm that previously owned TechGlobal)
+{record_delimiter}
+("relationship"{tuple_delimiter}TECHGLOBAL{tuple_delimiter}VISION HOLDINGS{tuple_delimiter}Vision Holdings formerly owned TechGlobal from 2014 until present{tuple_delimiter}5)
+{completion_delimiter}
+
+######################
+Example 3:
+Entity_types: ORGANIZATION,GEO,PERSON
+Text:
+Five Aurelians jailed for 8 years in Firuzabad and widely regarded as hostages are on their way home to Aurelia.
+
+The swap orchestrated by Quintara was finalized when $8bn of Firuzi funds were transferred to financial institutions in Krohaara, the capital of Quintara.
+
+The exchange initiated in Firuzabad's capital, Tiruzia, led to the four men and one woman, who are also Firuzi nationals, boarding a chartered flight to Krohaara.
+
+They were welcomed by senior Aurelian officials and are now on their way to Aurelia's capital, Cashion.
+
+The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality.
+######################
+Output:
+("entity"{tuple_delimiter}FIRUZABAD{tuple_delimiter}GEO{tuple_delimiter}Firuzabad held Aurelians as hostages)
+{record_delimiter}
+("entity"{tuple_delimiter}AURELIA{tuple_delimiter}GEO{tuple_delimiter}Country seeking to release hostages)
+{record_delimiter}
+("entity"{tuple_delimiter}QUINTARA{tuple_delimiter}GEO{tuple_delimiter}Country that negotiated a swap of money in exchange for hostages)
+{record_delimiter}
+("entity"{tuple_delimiter}TIRUZIA{tuple_delimiter}GEO{tuple_delimiter}Capital of Firuzabad where the Aurelians were being held)
+{record_delimiter}
+("entity"{tuple_delimiter}KROHAARA{tuple_delimiter}GEO{tuple_delimiter}Capital city in Quintara)
+{record_delimiter}
+("entity"{tuple_delimiter}CASHION{tuple_delimiter}GEO{tuple_delimiter}Capital city in Aurelia)
+{record_delimiter}
+("entity"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}PERSON{tuple_delimiter}Aurelian who spent time in Tiruzia's Alhamia Prison)
+{record_delimiter}
+("entity"{tuple_delimiter}ALHAMIA PRISON{tuple_delimiter}GEO{tuple_delimiter}Prison in Tiruzia)
+{record_delimiter}
+("entity"{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}PERSON{tuple_delimiter}Aurelian journalist who was held hostage)
+{record_delimiter}
+("entity"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}PERSON{tuple_delimiter}Bratinas national and environmentalist who was held hostage)
+{record_delimiter}
+("relationship"{tuple_delimiter}FIRUZABAD{tuple_delimiter}AURELIA{tuple_delimiter}Firuzabad negotiated a hostage exchange with Aurelia{tuple_delimiter}2)
+{record_delimiter}
+("relationship"{tuple_delimiter}QUINTARA{tuple_delimiter}AURELIA{tuple_delimiter}Quintara brokered the hostage exchange between Firuzabad and Aurelia{tuple_delimiter}2)
+{record_delimiter}
+("relationship"{tuple_delimiter}QUINTARA{tuple_delimiter}FIRUZABAD{tuple_delimiter}Quintara brokered the hostage exchange between Firuzabad and Aurelia{tuple_delimiter}2)
+{record_delimiter}
+("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}ALHAMIA PRISON{tuple_delimiter}Samuel Namara 
was a prisoner at Alhamia prison{tuple_delimiter}8) +{record_delimiter} +("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}Samuel Namara and Meggie Tazbah were exchanged in the same hostage release{tuple_delimiter}2) +{record_delimiter} +("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}Samuel Namara and Durke Bataglani were exchanged in the same hostage release{tuple_delimiter}2) +{record_delimiter} +("relationship"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release{tuple_delimiter}2) +{record_delimiter} +("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}FIRUZABAD{tuple_delimiter}Samuel Namara was a hostage in Firuzabad{tuple_delimiter}2) +{record_delimiter} +("relationship"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}FIRUZABAD{tuple_delimiter}Meggie Tazbah was a hostage in Firuzabad{tuple_delimiter}2) +{record_delimiter} +("relationship"{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}FIRUZABAD{tuple_delimiter}Durke Bataglani was a hostage in Firuzabad{tuple_delimiter}2) +{completion_delimiter} + +###################### +-Real Data- +###################### +Entity_types: {entity_types} +Text: {input_text} +###################### +Output:""" + +CONTINUE_PROMPT = "MANY entities and relationships were missed in the last extraction. Remember to ONLY emit entities that match any of the previously extracted types. Add them below using the same format:\n" +LOOP_PROMPT = "It appears some entities and relationships may have still been missed. Answer YES | NO if there are still entities or relationships that need to be added.\n" diff --git a/func-app/graphrag/index/graph/extractors/summarize/__init__.py b/func-app/graphrag/index/graph/extractors/summarize/__init__.py new file mode 100644 index 0000000000..b4bfe5be87 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/summarize/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine unipartite graph package root.""" + +from .description_summary_extractor import ( + SummarizationResult, + SummarizeExtractor, +) +from .prompts import SUMMARIZE_PROMPT + +__all__ = ["SUMMARIZE_PROMPT", "SummarizationResult", "SummarizeExtractor"] diff --git a/func-app/graphrag/index/graph/extractors/summarize/description_summary_extractor.py b/func-app/graphrag/index/graph/extractors/summarize/description_summary_extractor.py new file mode 100644 index 0000000000..76d77202d3 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/summarize/description_summary_extractor.py @@ -0,0 +1,135 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A module containing 'SummarizationResult' and 'SummarizeExtractor' models."""
+
+import json
+from dataclasses import dataclass
+
+from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.utils.tokens import num_tokens_from_string
+from graphrag.llm import CompletionLLM
+
+from .prompts import SUMMARIZE_PROMPT
+
+# Max token size for input prompts
+DEFAULT_MAX_INPUT_TOKENS = 4_000
+# Max token count for LLM answers
+DEFAULT_MAX_SUMMARY_LENGTH = 500
+
+
+@dataclass
+class SummarizationResult:
+    """Summarization result class definition."""
+
+    items: str | tuple[str, str]
+    description: str
+
+
+class SummarizeExtractor:
+    """Description summarization extractor class definition."""
+
+    _llm: CompletionLLM
+    _entity_name_key: str
+    _input_descriptions_key: str
+    _summarization_prompt: str
+    _on_error: ErrorHandlerFn
+    _max_summary_length: int
+    _max_input_tokens: int
+
+    def __init__(
+        self,
+        llm_invoker: CompletionLLM,
+        entity_name_key: str | None = None,
+        input_descriptions_key: str | None = None,
+        summarization_prompt: str | None = None,
+        on_error: ErrorHandlerFn | None = None,
+        max_summary_length: int | None = None,
+        max_input_tokens: int | None = None,
+    ):
+        """Init method definition."""
+        # TODO: streamline construction
+        self._llm = llm_invoker
+        self._entity_name_key = entity_name_key or "entity_name"
+        self._input_descriptions_key = input_descriptions_key or "description_list"
+
+        self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
+        self._on_error = on_error or (lambda _e, _s, _d: None)
+        self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
+        self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
+
+    async def __call__(
+        self,
+        items: str | tuple[str, str],
+        descriptions: list[str],
+    ) -> SummarizationResult:
+        """Call method definition."""
+        result = ""
+        if len(descriptions) == 0:
+            result = ""
+        elif len(descriptions) == 1:
+            result = descriptions[0]
+        else:
+            result = await self._summarize_descriptions(items, descriptions)
+
+        return SummarizationResult(
+            items=items,
+            description=result or "",
+        )
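+
+    # A minimal call sketch (hypothetical `llm`; awaited from an event loop):
+    #
+    #     extractor = SummarizeExtractor(llm_invoker=llm)
+    #     result = await extractor(
+    #         ("VERDANT OASIS PLAZA", "UNITY MARCH"),
+    #         ["Plaza hosting the march", "March held at the plaza"],
+    #     )
+    #     merged = result.description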
+
+    async def _summarize_descriptions(
+        self, items: str | tuple[str, str], descriptions: list[str]
+    ) -> str:
+        """Summarize descriptions into a single description."""
+        sorted_items = sorted(items) if isinstance(items, list) else items
+
+        # Safety check, should always be a list
+        if not isinstance(descriptions, list):
+            descriptions = [descriptions]
+
+        # Iterate over descriptions, adding all until the max input tokens is reached
+        usable_tokens = self._max_input_tokens - num_tokens_from_string(
+            self._summarization_prompt
+        )
+        descriptions_collected = []
+        result = ""
+
+        for i, description in enumerate(descriptions):
+            usable_tokens -= num_tokens_from_string(description)
+            descriptions_collected.append(description)
+
+            # If buffer is full, or all descriptions have been added, summarize
+            if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
+                i == len(descriptions) - 1
+            ):
+                # Calculate result (final or partial)
+                result = await self._summarize_descriptions_with_llm(
+                    sorted_items, descriptions_collected
+                )
+
+                # If we go for another loop, reset values to new
+                if i != len(descriptions) - 1:
+                    descriptions_collected = [result]
+                    usable_tokens = (
+                        self._max_input_tokens
+                        - num_tokens_from_string(self._summarization_prompt)
+                        - num_tokens_from_string(result)
+                    )
+
+        return result
+
+    async def _summarize_descriptions_with_llm(
+        self, items: str | tuple[str, str] | list[str], descriptions: list[str]
+    ):
+        """Summarize descriptions using the LLM."""
+        response = await self._llm(
+            self._summarization_prompt,
+            name="summarize",
+            variables={
+                self._entity_name_key: json.dumps(items),
+                self._input_descriptions_key: json.dumps(sorted(descriptions)),
+            },
+            model_parameters={"max_tokens": self._max_summary_length},
+        )
+        # Calculate result
+        return str(response.output)
diff --git a/func-app/graphrag/index/graph/extractors/summarize/prompts.py b/func-app/graphrag/index/graph/extractors/summarize/prompts.py
new file mode 100644
index 0000000000..90e4434ee8
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/summarize/prompts.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A file containing prompts definition."""
+
+SUMMARIZE_PROMPT = """
+You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
+You are given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
+Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
+If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
+Make sure it is written in third person, and include the entity names so we have the full context.
+
+#######
+-Data-
+Entities: {entity_name}
+Description List: {description_list}
+#######
+Output:
+"""
diff --git a/func-app/graphrag/index/graph/utils/__init__.py b/func-app/graphrag/index/graph/utils/__init__.py
new file mode 100644
index 0000000000..6d4479283a
--- /dev/null
+++ b/func-app/graphrag/index/graph/utils/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph utils package root."""
+
+from .normalize_node_names import normalize_node_names
+from .stable_lcc import stable_largest_connected_component
+
+__all__ = ["normalize_node_names", "stable_largest_connected_component"]
diff --git a/func-app/graphrag/index/graph/utils/normalize_node_names.py b/func-app/graphrag/index/graph/utils/normalize_node_names.py
new file mode 100644
index 0000000000..bcc874a927
--- /dev/null
+++ b/func-app/graphrag/index/graph/utils/normalize_node_names.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing normalize_node_names method definition."""
+
+import html
+
+import networkx as nx
+
+
+def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
+    """Normalize node names."""
+    node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}  # type: ignore
+    return nx.relabel_nodes(graph, node_mapping)
diff --git a/func-app/graphrag/index/graph/utils/stable_lcc.py b/func-app/graphrag/index/graph/utils/stable_lcc.py
new file mode 100644
index 0000000000..7d602a6ba7
--- /dev/null
+++ b/func-app/graphrag/index/graph/utils/stable_lcc.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module for producing a stable largest connected component, i.e. 
same input graph == same output lcc.""" + +from typing import Any, cast + +import networkx as nx +from graspologic.utils import largest_connected_component + +from .normalize_node_names import normalize_node_names + + +def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: + """Return the largest connected component of the graph, with nodes and edges sorted in a stable way.""" + graph = graph.copy() + graph = cast(nx.Graph, largest_connected_component(graph)) + graph = normalize_node_names(graph) + return _stabilize_graph(graph) + + +def _stabilize_graph(graph: nx.Graph) -> nx.Graph: + """Ensure an undirected graph with the same relationships will always be read the same way.""" + fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() + + sorted_nodes = graph.nodes(data=True) + sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) + + fixed_graph.add_nodes_from(sorted_nodes) + edges = list(graph.edges(data=True)) + + # If the graph is undirected, we create the edges in a stable way, so we get the same results + # for example: + # A -> B + # in graph theory is the same as + # B -> A + # in an undirected graph + # however, this can lead to downstream issues because sometimes + # consumers read graph.nodes() which ends up being [A, B] and sometimes it's [B, A] + # but they base some of their logic on the order of the nodes, so the order ends up being important + # so we sort the nodes in the edge in a stable way, so that we always get the same order + if not graph.is_directed(): + + def _sort_source_target(edge): + source, target, edge_data = edge + if source > target: + temp = source + source = target + target = temp + return source, target, edge_data + + edges = [_sort_source_target(edge) for edge in edges] + + def _get_edge_key(source: Any, target: Any) -> str: + return f"{source} -> {target}" + + edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) + + fixed_graph.add_edges_from(edges) + return fixed_graph diff --git a/func-app/graphrag/index/graph/visualization/__init__.py b/func-app/graphrag/index/graph/visualization/__init__.py new file mode 100644 index 0000000000..f7780e4e9c --- /dev/null +++ b/func-app/graphrag/index/graph/visualization/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph visualization package root.""" + +from .compute_umap_positions import compute_umap_positions, get_zero_positions +from .typing import GraphLayout, NodePosition + +__all__ = [ + "GraphLayout", + "NodePosition", + "compute_umap_positions", + "get_zero_positions", +] diff --git a/func-app/graphrag/index/graph/visualization/compute_umap_positions.py b/func-app/graphrag/index/graph/visualization/compute_umap_positions.py new file mode 100644 index 0000000000..569b7b309d --- /dev/null +++ b/func-app/graphrag/index/graph/visualization/compute_umap_positions.py @@ -0,0 +1,144 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing compute_umap_positions and visualize_embedding method definition.""" + +import graspologic as gc +import matplotlib.pyplot as plt +import networkx as nx +import numpy as np +import umap + +from .typing import NodePosition + + +def get_zero_positions( + node_labels: list[str], + node_categories: list[int] | None = None, + node_sizes: list[int] | None = None, + three_d: bool | None = False, +) -> list[NodePosition]: + """Project embedding vectors down to 2D/3D using UMAP.""" + embedding_position_data: list[NodePosition] = [] + for index, node_name in enumerate(node_labels): + node_category = 1 if node_categories is None else node_categories[index] + node_size = 1 if node_sizes is None else node_sizes[index] + + if not three_d: + embedding_position_data.append( + NodePosition( + label=str(node_name), + x=0, + y=0, + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + else: + embedding_position_data.append( + NodePosition( + label=str(node_name), + x=0, + y=0, + z=0, + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + return embedding_position_data + + +def compute_umap_positions( + embedding_vectors: np.ndarray, + node_labels: list[str], + node_categories: list[int] | None = None, + node_sizes: list[int] | None = None, + min_dist: float = 0.75, + n_neighbors: int = 25, + spread: int = 1, + metric: str = "euclidean", + n_components: int = 2, + random_state: int = 86, +) -> list[NodePosition]: + """Project embedding vectors down to 2D/3D using UMAP.""" + embedding_positions = umap.UMAP( + min_dist=min_dist, + n_neighbors=n_neighbors, + spread=spread, + n_components=n_components, + metric=metric, + random_state=random_state, + ).fit_transform(embedding_vectors) + + embedding_position_data: list[NodePosition] = [] + for index, node_name in enumerate(node_labels): + node_points = embedding_positions[index] # type: ignore + node_category = 1 if node_categories is None else node_categories[index] + node_size = 1 if node_sizes is None else node_sizes[index] + + if len(node_points) == 2: + embedding_position_data.append( + NodePosition( + label=str(node_name), + x=float(node_points[0]), + y=float(node_points[1]), + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + else: + embedding_position_data.append( + NodePosition( + label=str(node_name), + x=float(node_points[0]), + y=float(node_points[1]), + z=float(node_points[2]), + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + return embedding_position_data + + +def visualize_embedding( + graph, + umap_positions: list[dict], +): + """Project embedding down to 2D using UMAP and visualize.""" + # rendering + plt.clf() + figure = plt.gcf() + ax = plt.gca() + + ax.set_axis_off() + figure.set_size_inches(10, 10) + figure.set_dpi(400) + + node_position_dict = { + (str)(position["label"]): (position["x"], position["y"]) + for position in umap_positions + } + node_category_dict = { + (str)(position["label"]): position["category"] for position in umap_positions + } + node_sizes = [position["size"] for position in umap_positions] + node_colors = gc.layouts.categorical_colors(node_category_dict) # type: ignore + + vertices = [] + node_color_list = [] + for node in node_position_dict: + vertices.append(node) + node_color_list.append(node_colors[node]) + + nx.draw_networkx_nodes( + graph, + pos=node_position_dict, + nodelist=vertices, + node_color=node_color_list, # type: ignore + alpha=1.0, + linewidths=0.01, + node_size=node_sizes, # type: ignore 
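+        # positions, colors, and sizes are all aligned with `vertices` from the umap_positions payload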
+        node_shape="o",
+        ax=ax,
+    )
+    plt.show()
diff --git a/func-app/graphrag/index/graph/visualization/typing.py b/func-app/graphrag/index/graph/visualization/typing.py
new file mode 100644
index 0000000000..ae46afa928
--- /dev/null
+++ b/func-app/graphrag/index/graph/visualization/typing.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+# Use this for now instead of a wrapper
+"""A module containing 'NodePosition' model."""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class NodePosition:
+    """Node position class definition."""
+
+    label: str
+    cluster: str
+    size: float
+
+    x: float
+    y: float
+    z: float | None = None
+
+    def to_pandas(self) -> tuple[str, float, float, str, float]:
+        """To pandas method definition."""
+        return self.label, self.x, self.y, self.cluster, self.size
+
+
+GraphLayout = list[NodePosition]
diff --git a/func-app/graphrag/index/init_content.py b/func-app/graphrag/index/init_content.py
new file mode 100644
index 0000000000..8f5982f8f7
--- /dev/null
+++ b/func-app/graphrag/index/init_content.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+"""Content for the init CLI command."""
+
+import graphrag.config.defaults as defs
+
+INIT_YAML = f"""
+encoding_model: cl100k_base
+skip_workflows: []
+llm:
+  api_key: ${{GRAPHRAG_API_KEY}}
+  type: {defs.LLM_TYPE.value} # or azure_openai_chat
+  model: {defs.LLM_MODEL}
+  model_supports_json: true # recommended if this is available for your model.
+  # max_tokens: {defs.LLM_MAX_TOKENS}
+  # request_timeout: {defs.LLM_REQUEST_TIMEOUT}
+  # api_base: https://<instance>.openai.azure.com
+  # api_version: 2024-02-15-preview
+  # organization: <organization_id>
+  # deployment_name: <azure_model_deployment_name>
+  # tokens_per_minute: 150_000 # set a leaky bucket throttle
+  # requests_per_minute: 10_000 # set a leaky bucket throttle
+  # max_retries: {defs.LLM_MAX_RETRIES}
+  # max_retry_wait: {defs.LLM_MAX_RETRY_WAIT}
+  # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
+  # concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # the number of parallel inflight requests that may be made
+  # temperature: {defs.LLM_TEMPERATURE} # temperature for sampling
+  # top_p: {defs.LLM_TOP_P} # top-p sampling
+  # n: {defs.LLM_N} # Number of completions to generate
+
+parallelization:
+  stagger: {defs.PARALLELIZATION_STAGGER}
+  # num_threads: {defs.PARALLELIZATION_NUM_THREADS} # the number of threads to use for parallel processing
+
+async_mode: {defs.ASYNC_MODE.value} # or asyncio
+
+embeddings:
+  ## parallelization: override the global parallelization settings for embeddings
+  async_mode: {defs.ASYNC_MODE.value} # or asyncio
+  llm:
+    api_key: ${{GRAPHRAG_API_KEY}}
+    type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
+    model: {defs.EMBEDDING_MODEL}
+    # api_base: https://<instance>.openai.azure.com
+    # api_version: 2024-02-15-preview
+    # organization: <organization_id>
+    # deployment_name: <azure_model_deployment_name>
+    # tokens_per_minute: 150_000 # set a leaky bucket throttle
+    # requests_per_minute: 10_000 # set a leaky bucket throttle
+    # max_retries: {defs.LLM_MAX_RETRIES}
+    # max_retry_wait: {defs.LLM_MAX_RETRY_WAIT}
+    # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
+    # concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # the number of parallel inflight requests that may be made
+    # batch_size: {defs.EMBEDDING_BATCH_SIZE} # the number of documents to send in a single request
+    # batch_max_tokens: {defs.EMBEDDING_BATCH_MAX_TOKENS} # the maximum number of tokens to send in a single request
+    # target: {defs.EMBEDDING_TARGET.value} # or optional
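+    # `target` selects which embeddings are generated ("all" vs. "required")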
+
+
+chunks:
+  size: {defs.CHUNK_SIZE}
+  overlap: {defs.CHUNK_OVERLAP}
+  group_by_columns: [{",".join(defs.CHUNK_GROUP_BY_COLUMNS)}] # by default, we don't allow chunks to cross documents
+
+input:
+  type: {defs.INPUT_TYPE.value} # or blob
+  file_type: {defs.INPUT_FILE_TYPE.value} # or csv
+  base_dir: "{defs.INPUT_BASE_DIR}"
+  file_encoding: {defs.INPUT_FILE_ENCODING}
+  file_pattern: ".*\\\\.txt$"
+
+cache:
+  type: {defs.CACHE_TYPE.value} # or blob
+  base_dir: "{defs.CACHE_BASE_DIR}"
+  # connection_string: <azure_blob_storage_connection_string>
+  # container_name: <azure_blob_storage_container_name>
+
+storage:
+  type: {defs.STORAGE_TYPE.value} # or blob
+  base_dir: "{defs.STORAGE_BASE_DIR}"
+  # connection_string: <azure_blob_storage_connection_string>
+  # container_name: <azure_blob_storage_container_name>
+
+reporting:
+  type: {defs.REPORTING_TYPE.value} # or console, blob
+  base_dir: "{defs.REPORTING_BASE_DIR}"
+  # connection_string: <azure_blob_storage_connection_string>
+  # container_name: <azure_blob_storage_container_name>
+
+entity_extraction:
+  ## llm: override the global llm settings for this task
+  ## parallelization: override the global parallelization settings for this task
+  ## async_mode: override the global async_mode settings for this task
+  prompt: "prompts/entity_extraction.txt"
+  entity_types: [{",".join(defs.ENTITY_EXTRACTION_ENTITY_TYPES)}]
+  max_gleanings: {defs.ENTITY_EXTRACTION_MAX_GLEANINGS}
+
+summarize_descriptions:
+  ## llm: override the global llm settings for this task
+  ## parallelization: override the global parallelization settings for this task
+  ## async_mode: override the global async_mode settings for this task
+  prompt: "prompts/summarize_descriptions.txt"
+  max_length: {defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH}
+
+claim_extraction:
+  ## llm: override the global llm settings for this task
+  ## parallelization: override the global parallelization settings for this task
+  ## async_mode: override the global async_mode settings for this task
+  # enabled: true
+  prompt: "prompts/claim_extraction.txt"
+  description: "{defs.CLAIM_DESCRIPTION}"
+  max_gleanings: {defs.CLAIM_MAX_GLEANINGS}
+
+community_reports:
+  ## llm: override the global llm settings for this task
+  ## parallelization: override the global parallelization settings for this task
+  ## async_mode: override the global async_mode settings for this task
+  prompt: "prompts/community_report.txt"
+  max_length: {defs.COMMUNITY_REPORT_MAX_LENGTH}
+  max_input_length: {defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH}
+
+cluster_graph:
+  max_cluster_size: {defs.MAX_CLUSTER_SIZE}
+
+embed_graph:
+  enabled: false # if true, will generate node2vec embeddings for nodes
+  # num_walks: {defs.NODE2VEC_NUM_WALKS}
+  # walk_length: {defs.NODE2VEC_WALK_LENGTH}
+  # window_size: {defs.NODE2VEC_WINDOW_SIZE}
+  # iterations: {defs.NODE2VEC_ITERATIONS}
+  # random_seed: {defs.NODE2VEC_RANDOM_SEED}
+
+umap:
+  enabled: false # if true, will generate UMAP embeddings for nodes
+
+snapshots:
+  graphml: false
+  raw_entities: false
+  top_level_nodes: false
+
+local_search:
+  # text_unit_prop: {defs.LOCAL_SEARCH_TEXT_UNIT_PROP}
+  # community_prop: {defs.LOCAL_SEARCH_COMMUNITY_PROP}
+  # conversation_history_max_turns: {defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS}
+  # top_k_mapped_entities: {defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES}
+  # top_k_relationships: {defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS}
+  # llm_temperature: {defs.LOCAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling
+  # llm_top_p: {defs.LOCAL_SEARCH_LLM_TOP_P} # top-p sampling
+  # llm_n: {defs.LOCAL_SEARCH_LLM_N} # Number of completions to generate
+  # max_tokens: {defs.LOCAL_SEARCH_MAX_TOKENS}
+
+global_search:
+  # llm_temperature: {defs.GLOBAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling
+  # llm_top_p: {defs.GLOBAL_SEARCH_LLM_TOP_P} # top-p sampling
+  # llm_n: {defs.GLOBAL_SEARCH_LLM_N} # Number of completions to generate
+  # max_tokens: {defs.GLOBAL_SEARCH_MAX_TOKENS}
+  # data_max_tokens: {defs.GLOBAL_SEARCH_DATA_MAX_TOKENS}
+  # map_max_tokens: {defs.GLOBAL_SEARCH_MAP_MAX_TOKENS}
+  # reduce_max_tokens: {defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS}
+  # concurrency: {defs.GLOBAL_SEARCH_CONCURRENCY}
+
+query_context:
+  # files: [] # list of input files to scope the query to
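+  # e.g. (hypothetical artifact paths):
+  # files: ["output/run-1/artifacts", "output/run-2/artifacts"]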
+
+graphdb:
+  account_name: ''
+  account_key: ''
+  username: ''
+  enabled: false
+  cosmos_url: ''
+  gremlin_url: ''
+"""

+INIT_DOTENV = """
+GRAPHRAG_API_KEY=<API_KEY>
+"""
diff --git a/func-app/graphrag/index/input/__init__.py b/func-app/graphrag/index/input/__init__.py
new file mode 100644
index 0000000000..91421867de
--- /dev/null
+++ b/func-app/graphrag/index/input/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine input package root."""
+
+from .load_input import load_input
+
+__all__ = ["load_input"]
diff --git a/func-app/graphrag/index/input/csv.py b/func-app/graphrag/index/input/csv.py
new file mode 100644
index 0000000000..04f43ddda0
--- /dev/null
+++ b/func-app/graphrag/index/input/csv.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing load method definition."""
+
+import logging
+import re
+from io import BytesIO
+from typing import cast
+
+import pandas as pd
+
+from graphrag.index.config import PipelineCSVInputConfig, PipelineInputConfig
+from graphrag.common.progress import ProgressReporter
+from graphrag.common.storage import PipelineStorage
+from graphrag.index.utils import gen_md5_hash
+
+log = logging.getLogger(__name__)
+
+DEFAULT_FILE_PATTERN = re.compile(r"(?P<filename>[^\\/]).csv$")
+
+input_type = "csv"
+
+
+async def load(
+    config: PipelineInputConfig,
+    progress: ProgressReporter | None,
+    storage: PipelineStorage,
+) -> pd.DataFrame:
+    """Load csv inputs from a directory."""
+    csv_config = cast(PipelineCSVInputConfig, config)
+    log.info("Loading csv files from %s", csv_config.base_dir)
+
+    async def load_file(path: str, group: dict | None) -> pd.DataFrame:
+        if group is None:
+            group = {}
+        buffer = BytesIO(await storage.get(path, as_bytes=True))
+        data = pd.read_csv(buffer, encoding=config.encoding or "latin-1")
+        additional_keys = group.keys()
+        if len(additional_keys) > 0:
+            data[[*additional_keys]] = data.apply(
+                lambda _row: pd.Series([group[key] for key in additional_keys]), axis=1
+            )
+        if "id" not in data.columns:
+            data["id"] = data.apply(lambda x: gen_md5_hash(x, x.keys()), axis=1)
+        if csv_config.source_column is not None and "source" not in data.columns:
+            if csv_config.source_column not in data.columns:
+                log.warning(
+                    "source_column %s not found in csv file %s",
+                    csv_config.source_column,
+                    path,
+                )
+            else:
+                data["source"] = data.apply(
+                    lambda x: x[csv_config.source_column], axis=1
+                )
+        if csv_config.text_column is not None and "text" not in data.columns:
+            if csv_config.text_column not in data.columns:
+                log.warning(
+                    "text_column %s not found in csv file %s",
+                    csv_config.text_column,
+                    path,
+                )
+            else:
+                data["text"] = data.apply(lambda x: x[csv_config.text_column], axis=1)
+        if csv_config.title_column is not None and "title" not in data.columns:
+            if csv_config.title_column not in data.columns:
+                log.warning(
+                    "title_column %s not found in 
csv file %s", + csv_config.title_column, + path, + ) + else: + data["title"] = data.apply(lambda x: x[csv_config.title_column], axis=1) + + if csv_config.timestamp_column is not None: + fmt = csv_config.timestamp_format + if fmt is None: + msg = "Must specify timestamp_format if timestamp_column is specified" + raise ValueError(msg) + + if csv_config.timestamp_column not in data.columns: + log.warning( + "timestamp_column %s not found in csv file %s", + csv_config.timestamp_column, + path, + ) + else: + data["timestamp"] = pd.to_datetime( + data[csv_config.timestamp_column], format=fmt + ) + + # TODO: Theres probably a less gross way to do this + if "year" not in data.columns: + data["year"] = data.apply(lambda x: x["timestamp"].year, axis=1) + if "month" not in data.columns: + data["month"] = data.apply(lambda x: x["timestamp"].month, axis=1) + if "day" not in data.columns: + data["day"] = data.apply(lambda x: x["timestamp"].day, axis=1) + if "hour" not in data.columns: + data["hour"] = data.apply(lambda x: x["timestamp"].hour, axis=1) + if "minute" not in data.columns: + data["minute"] = data.apply(lambda x: x["timestamp"].minute, axis=1) + if "second" not in data.columns: + data["second"] = data.apply(lambda x: x["timestamp"].second, axis=1) + + return data + + file_pattern = ( + re.compile(config.file_pattern) + if config.file_pattern is not None + else DEFAULT_FILE_PATTERN + ) + files = list( + storage.find( + file_pattern, + progress=progress, + file_filter=config.file_filter, + ) + ) + + if len(files) == 0: + msg = f"No CSV files found in {config.base_dir}" + raise ValueError(msg) + + files_loaded = [] + + for file, group in files: + try: + files_loaded.append(await load_file(file, group)) + except Exception: # noqa: BLE001 (catching Exception is fine here) + log.warning("Warning! Error loading csv file %s. Skipping...", file) + + log.info("Found %d csv files, loading %d", len(files), len(files_loaded)) + result = pd.concat(files_loaded) + total_files_log = f"Total number of unfiltered csv rows: {len(result)}" + log.info(total_files_log) + return result diff --git a/func-app/graphrag/index/input/load_input.py b/func-app/graphrag/index/input/load_input.py new file mode 100644 index 0000000000..14dd90635e --- /dev/null +++ b/func-app/graphrag/index/input/load_input.py @@ -0,0 +1,84 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing load_input method definition.""" + +import logging +from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import cast + +import pandas as pd + +from graphrag.config import InputConfig, InputType +from graphrag.index.config import PipelineInputConfig +from graphrag.common.progress import NullProgressReporter, ProgressReporter +from graphrag.common.storage import ( + BlobPipelineStorage, + FilePipelineStorage, +) + +from .csv import input_type as csv +from .csv import load as load_csv +from .text import input_type as text +from .text import load as load_text + +log = logging.getLogger(__name__) +loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = { + text: load_text, + csv: load_csv, +} + + +async def load_input( + config: PipelineInputConfig | InputConfig, + progress_reporter: ProgressReporter | None = None, + root_dir: str | None = None, +) -> pd.DataFrame: + """Load the input data for a pipeline.""" + root_dir = root_dir or "" + log.info(f"loading input from root_dir {root_dir}") + progress_reporter = progress_reporter or NullProgressReporter() + + if config is None: + msg = "No input specified!" + raise ValueError(msg) + + match config.type: + case InputType.blob: + log.info("using blob storage input") + if config.container_name is None: + msg = "Container name required for blob storage" + raise ValueError(msg) + if ( + config.connection_string is None + and config.storage_account_blob_url is None + ): + msg = "Connection string or storage account blob url required for blob storage" + raise ValueError(msg) + storage = BlobPipelineStorage( + connection_string=config.connection_string, + storage_account_blob_url=config.storage_account_blob_url, + container_name=config.container_name, + ) + case InputType.file: + log.info("using file storage for input") + storage = FilePipelineStorage( + root_dir=str(Path(root_dir) / (config.base_dir or "")) + ) + case _: + log.info("using file storage for input") + storage = FilePipelineStorage( + root_dir=str(Path(root_dir) / (config.base_dir or "")) + ) + + if config.file_type in loaders: + progress = progress_reporter.child( + f"Loading Input ({config.file_type})", transient=False + ) + loader = loaders[config.file_type] + results = await loader(config, progress, storage) + return cast(pd.DataFrame, results) + + msg = f"Unknown input type {config.file_type}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/input/text.py b/func-app/graphrag/index/input/text.py new file mode 100644 index 0000000000..3a3e15cbd9 --- /dev/null +++ b/func-app/graphrag/index/input/text.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A module containing load method definition."""
+
+import logging
+import re
+from pathlib import Path
+from typing import Any
+
+import pandas as pd
+
+from graphrag.index.config import PipelineInputConfig
+from graphrag.common.progress import ProgressReporter
+from graphrag.common.storage import PipelineStorage
+from graphrag.index.utils import gen_md5_hash
+
+DEFAULT_FILE_PATTERN = re.compile(
+    r".*[\\/](?P<source>[^\\/]+)[\\/](?P<year>\d{4})-(?P<month>\d{2})-(?P<day>\d{2})_(?P<author>[^_]+)_\d+\.txt"
+)
+input_type = "text"
+log = logging.getLogger(__name__)
+
+
+async def load(
+    config: PipelineInputConfig,
+    progress: ProgressReporter | None,
+    storage: PipelineStorage,
+) -> pd.DataFrame:
+    """Load text inputs from a directory."""
+
+    async def load_file(
+        # `group` carries metadata captured from the file-pattern match and is
+        # merged into the loaded row as additional context columns.
+        path: str, group: dict | None = None, _encoding: str = "utf-8"
+    ) -> dict[str, Any]:
+        if group is None:
+            group = {}
+        text = await storage.get(path, encoding="utf-8")
+        new_item = {**group, "text": text}
+        new_item["id"] = gen_md5_hash(new_item, new_item.keys())
+        new_item["title"] = str(Path(path).name)
+        return new_item
+
+    base_dir = config.base_dir
+    if config.type == "file":
+        # base_dir has already been appended to the root dir for file-type input.
+        base_dir = None
+    files = list(
+        storage.find(
+            re.compile(config.file_pattern),
+            progress=progress,
+            file_filter=config.file_filter,
+            base_dir=base_dir,
+        )
+    )
+
+    if len(files) == 0:
+        msg = f"No text files found in {config.base_dir}"
+        raise ValueError(msg)
+
+    found_files = f"found text files in {config.base_dir}: {files}"
+    log.info(found_files)
+
+    files_loaded = []
+
+    for file, group in files:
+        try:
+            files_loaded.append(await load_file(file, group))
+        except Exception:  # noqa: BLE001 (catching Exception is fine here)
+            log.warning("Warning! Error loading file %s. Skipping...", file)
+
+    log.info("Found %d files, loading %d", len(files), len(files_loaded))
+
+    return pd.DataFrame(files_loaded)
diff --git a/func-app/graphrag/index/llm/__init__.py b/func-app/graphrag/index/llm/__init__.py
new file mode 100644
index 0000000000..008ef07ccd
--- /dev/null
+++ b/func-app/graphrag/index/llm/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine LLM package root."""
+
+from .load_llm import load_llm, load_llm_embeddings
+from .types import TextListSplitter, TextSplitter
+
+__all__ = [
+    "TextListSplitter",
+    "TextSplitter",
+    "load_llm",
+    "load_llm_embeddings",
+]
diff --git a/func-app/graphrag/index/llm/load_llm.py b/func-app/graphrag/index/llm/load_llm.py
new file mode 100644
index 0000000000..264229c887
--- /dev/null
+++ b/func-app/graphrag/index/llm/load_llm.py
@@ -0,0 +1,313 @@
+# Copyright (c) 2024 Microsoft Corporation.
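+#
+# NOTE: the `loaders` registry below maps each LLMType to a factory plus a
+# chat-support flag, so load_llm() is a guarded dictionary dispatch over it.
+# Illustrative call (the callbacks/cache objects and config values here are
+# assumptions, not prescribed values):
+#
+#     llm = load_llm(
+#         "entity_extraction",          # also used to scope the cache
+#         LLMType.OpenAIChat,
+#         callbacks,                    # a datashaper VerbCallbacks
+#         cache,                        # a PipelineCache, or None
+#         llm_config={"api_key": "...", "model": "gpt-4-turbo-preview"},
+#         chat_only=True,               # raises ValueError for non-chat types
+#     )
+#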
+# Licensed under the MIT License + +"""Load llm utilities.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import TYPE_CHECKING, Any + +from graphrag.config.enums import LLMType +from graphrag.llm import ( + CompletionLLM, + EmbeddingLLM, + LLMCache, + LLMLimiter, + MockCompletionLLM, + OpenAIConfiguration, + create_openai_chat_llm, + create_openai_client, + create_openai_completion_llm, + create_openai_embedding_llm, + create_tpm_rpm_limiters, +) + +if TYPE_CHECKING: + from datashaper import VerbCallbacks + + from graphrag.index.cache import PipelineCache + from graphrag.index.typing import ErrorHandlerFn + +log = logging.getLogger(__name__) + +_semaphores: dict[str, asyncio.Semaphore] = {} +_rate_limiters: dict[str, LLMLimiter] = {} + + +def load_llm( + name: str, + llm_type: LLMType, + callbacks: VerbCallbacks, + cache: PipelineCache | None, + llm_config: dict[str, Any] | None = None, + chat_only=False, +) -> CompletionLLM: + """Load the LLM for the entity extraction chain.""" + on_error = _create_error_handler(callbacks) + + if llm_type in loaders: + if chat_only and not loaders[llm_type]["chat"]: + msg = f"LLM type {llm_type} does not support chat" + raise ValueError(msg) + if cache is not None: + cache = cache.child(name) + + loader = loaders[llm_type] + return loader["load"](on_error, cache, llm_config or {}) + + msg = f"Unknown LLM type {llm_type}" + raise ValueError(msg) + + +def load_llm_embeddings( + name: str, + llm_type: LLMType, + callbacks: VerbCallbacks, + cache: PipelineCache | None, + llm_config: dict[str, Any] | None = None, + chat_only=False, +) -> EmbeddingLLM: + """Load the LLM for the entity extraction chain.""" + on_error = _create_error_handler(callbacks) + if llm_type in loaders: + if chat_only and not loaders[llm_type]["chat"]: + msg = f"LLM type {llm_type} does not support chat" + raise ValueError(msg) + if cache is not None: + cache = cache.child(name) + + return loaders[llm_type]["load"](on_error, cache, llm_config or {}) + + msg = f"Unknown LLM type {llm_type}" + raise ValueError(msg) + + +def _create_error_handler(callbacks: VerbCallbacks) -> ErrorHandlerFn: + def on_error( + error: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ) -> None: + callbacks.error("Error Invoking LLM", error, stack, details) + + return on_error + + +def _load_openai_completion_llm( + on_error: ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], + azure=False, +): + return _create_openai_completion_llm( + OpenAIConfiguration({ + **_get_base_config(config), + "model": config.get("model", "gpt-4-turbo-preview"), + "deployment_name": config.get("deployment_name"), + "temperature": config.get("temperature", 0.0), + "frequency_penalty": config.get("frequency_penalty", 0), + "presence_penalty": config.get("presence_penalty", 0), + "top_p": config.get("top_p", 1), + "max_tokens": config.get("max_tokens", 4000), + "n": config.get("n"), + }), + on_error, + cache, + azure, + ) + + +def _load_openai_chat_llm( + on_error: ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], + azure=False, +): + return _create_openai_chat_llm( + OpenAIConfiguration({ + # Set default values + **_get_base_config(config), + "model": config.get("model", "gpt-4-turbo-preview"), + "deployment_name": config.get("deployment_name"), + "temperature": config.get("temperature", 0.0), + "frequency_penalty": config.get("frequency_penalty", 0), + "presence_penalty": config.get("presence_penalty", 0), + "top_p": 
config.get("top_p", 1), + "max_tokens": config.get("max_tokens"), + "n": config.get("n"), + }), + on_error, + cache, + azure, + ) + + +def _load_openai_embeddings_llm( + on_error: ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], + azure=False, +): + # TODO: Inject Cache + return _create_openai_embeddings_llm( + OpenAIConfiguration({ + **_get_base_config(config), + "model": config.get( + "embeddings_model", config.get("model", "text-embedding-3-small") + ), + "deployment_name": config.get("deployment_name"), + }), + on_error, + cache, + azure, + ) + + +def _load_azure_openai_completion_llm( + on_error: ErrorHandlerFn, cache: LLMCache, config: dict[str, Any] +): + return _load_openai_completion_llm(on_error, cache, config, True) + + +def _load_azure_openai_chat_llm( + on_error: ErrorHandlerFn, cache: LLMCache, config: dict[str, Any] +): + return _load_openai_chat_llm(on_error, cache, config, True) + + +def _load_azure_openai_embeddings_llm( + on_error: ErrorHandlerFn, cache: LLMCache, config: dict[str, Any] +): + return _load_openai_embeddings_llm(on_error, cache, config, True) + + +def _get_base_config(config: dict[str, Any]) -> dict[str, Any]: + api_key = config.get("api_key") + + return { + # Pass in all parameterized values + **config, + # Set default values + "api_key": api_key, + "api_base": config.get("api_base"), + "api_version": config.get("api_version"), + "organization": config.get("organization"), + "proxy": config.get("proxy"), + "max_retries": config.get("max_retries", 10), + "request_timeout": config.get("request_timeout", 60.0), + "model_supports_json": config.get("model_supports_json"), + "concurrent_requests": config.get("concurrent_requests", 4), + "encoding_model": config.get("encoding_model", "cl100k_base"), + "cognitive_services_endpoint": config.get("cognitive_services_endpoint"), + } + + +def _load_static_response( + _on_error: ErrorHandlerFn, _cache: PipelineCache, config: dict[str, Any] +) -> CompletionLLM: + return MockCompletionLLM(config.get("responses", [])) + + +loaders = { + LLMType.OpenAI: { + "load": _load_openai_completion_llm, + "chat": False, + }, + LLMType.AzureOpenAI: { + "load": _load_azure_openai_completion_llm, + "chat": False, + }, + LLMType.OpenAIChat: { + "load": _load_openai_chat_llm, + "chat": True, + }, + LLMType.AzureOpenAIChat: { + "load": _load_azure_openai_chat_llm, + "chat": True, + }, + LLMType.OpenAIEmbedding: { + "load": _load_openai_embeddings_llm, + "chat": False, + }, + LLMType.AzureOpenAIEmbedding: { + "load": _load_azure_openai_embeddings_llm, + "chat": False, + }, + LLMType.StaticResponse: { + "load": _load_static_response, + "chat": False, + }, +} + + +def _create_openai_chat_llm( + configuration: OpenAIConfiguration, + on_error: ErrorHandlerFn, + cache: LLMCache, + azure=False, +) -> CompletionLLM: + """Create an openAI chat llm.""" + client = create_openai_client(configuration=configuration, azure=azure) + limiter = _create_limiter(configuration) + semaphore = _create_semaphore(configuration) + return create_openai_chat_llm( + client, configuration, cache, limiter, semaphore, on_error=on_error + ) + + +def _create_openai_completion_llm( + configuration: OpenAIConfiguration, + on_error: ErrorHandlerFn, + cache: LLMCache, + azure=False, +) -> CompletionLLM: + """Create an openAI completion llm.""" + client = create_openai_client(configuration=configuration, azure=azure) + limiter = _create_limiter(configuration) + semaphore = _create_semaphore(configuration) + return create_openai_completion_llm( + client, 
configuration, cache, limiter, semaphore, on_error=on_error
+    )
+
+
+def _create_openai_embeddings_llm(
+    configuration: OpenAIConfiguration,
+    on_error: ErrorHandlerFn,
+    cache: LLMCache,
+    azure=False,
+) -> EmbeddingLLM:
+    """Create an openAI embeddings llm."""
+    client = create_openai_client(configuration=configuration, azure=azure)
+    limiter = _create_limiter(configuration)
+    semaphore = _create_semaphore(configuration)
+    return create_openai_embedding_llm(
+        client, configuration, cache, limiter, semaphore, on_error=on_error
+    )
+
+
+def _create_limiter(configuration: OpenAIConfiguration) -> LLMLimiter:
+    limit_name = configuration.model or configuration.deployment_name or "default"
+    if limit_name not in _rate_limiters:
+        tpm = configuration.tokens_per_minute
+        rpm = configuration.requests_per_minute
+        log.info("create TPM/RPM limiter for %s: TPM=%s, RPM=%s", limit_name, tpm, rpm)
+        _rate_limiters[limit_name] = create_tpm_rpm_limiters(configuration)
+    return _rate_limiters[limit_name]
+
+
+def _create_semaphore(configuration: OpenAIConfiguration) -> asyncio.Semaphore | None:
+    limit_name = configuration.model or configuration.deployment_name or "default"
+    concurrency = configuration.concurrent_requests
+
+    # bypass the semaphore if concurrency is zero
+    if not concurrency:
+        log.info("no concurrency limiter for %s", limit_name)
+        return None
+
+    if limit_name not in _semaphores:
+        log.info("create concurrency limiter for %s: %s", limit_name, concurrency)
+        _semaphores[limit_name] = asyncio.Semaphore(concurrency)
+
+    return _semaphores[limit_name]
diff --git a/func-app/graphrag/index/llm/types.py b/func-app/graphrag/index/llm/types.py
new file mode 100644
index 0000000000..73c47737cb
--- /dev/null
+++ b/func-app/graphrag/index/llm/types.py
@@ -0,0 +1,10 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the 'TextSplitter' and 'TextListSplitter' type aliases."""
+
+from collections.abc import Callable
+from typing import TypeAlias
+
+TextSplitter: TypeAlias = Callable[[str], list[str]]
+TextListSplitter: TypeAlias = Callable[[list[str]], list[str]]
diff --git a/func-app/graphrag/index/load_pipeline_config.py b/func-app/graphrag/index/load_pipeline_config.py
new file mode 100644
index 0000000000..7488c8c4fe
--- /dev/null
+++ b/func-app/graphrag/index/load_pipeline_config.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing read_dotenv, load_pipeline_config, _parse_yaml and _create_include_constructor methods definition."""
+
+import json
+import logging
+from pathlib import Path
+
+import yaml
+from pyaml_env import parse_config as parse_config_with_env
+
+from graphrag.config import create_graphrag_config, read_dotenv
+from graphrag.index.config import PipelineConfig
+
+from .create_pipeline_config import create_pipeline_config
+
+log = logging.getLogger(__name__)
+
+
+def load_pipeline_config(config_or_path: str | PipelineConfig) -> PipelineConfig:
+    """Load a pipeline config from a file path or a config object."""
+    if isinstance(config_or_path, PipelineConfig):
+        log.info("Loading pipeline config from an existing PipelineConfig instance")
+        config = config_or_path
+    elif config_or_path == "default":
+        config = create_pipeline_config(create_graphrag_config(root_dir="."))
+    else:
+        # Is there a .env file in the same directory as the config?
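+        # If so, read_dotenv() loads it into the process environment so that
+        # environment references in the YAML can be resolved by pyaml_env's
+        # parse_config in _parse_yaml(), e.g. (illustrative snippet, using
+        # pyaml_env's !ENV tag syntax):
+        #
+        #     llm:
+        #       api_key: !ENV ${GRAPHRAG_API_KEY}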
+ read_dotenv(str(Path(config_or_path).parent)) + + if config_or_path.endswith(".json"): + with Path(config_or_path).open("rb") as f: + config = json.loads(f.read().decode(encoding="utf-8", errors="strict")) + elif config_or_path.endswith((".yml", ".yaml")): + config = _parse_yaml(config_or_path) + else: + msg = f"Invalid config file type: {config_or_path}" + raise ValueError(msg) + + config = PipelineConfig.model_validate(config) + if not config.root_dir: + config.root_dir = str(Path(config_or_path).parent.resolve()) + + if config.extends is not None: + if isinstance(config.extends, str): + config.extends = [config.extends] + for extended_config in config.extends: + extended_config = load_pipeline_config(extended_config) + merged_config = { + **json.loads(extended_config.model_dump_json()), + **json.loads(config.model_dump_json(exclude_unset=True)), + } + config = PipelineConfig.model_validate(merged_config) + + return config + + +def _parse_yaml(path: str): + """Parse a yaml file, with support for !include directives.""" + # I don't like that this is static + loader_class = yaml.SafeLoader + + # Add !include constructor if not already present. + if "!include" not in loader_class.yaml_constructors: + loader_class.add_constructor("!include", _create_include_constructor()) + + return parse_config_with_env(path, loader=loader_class, default_value="") + + +def _create_include_constructor(): + """Create a constructor for !include directives.""" + + def handle_include(loader: yaml.Loader, node: yaml.Node): + """Include file referenced at node.""" + filename = str(Path(loader.name).parent / node.value) + if filename.endswith((".yml", ".yaml")): + return _parse_yaml(filename) + + with Path(filename).open("rb") as f: + return f.read().decode(encoding="utf-8", errors="strict") + + return handle_include diff --git a/func-app/graphrag/index/py.typed b/func-app/graphrag/index/py.typed new file mode 100644 index 0000000000..f4bd298955 --- /dev/null +++ b/func-app/graphrag/index/py.typed @@ -0,0 +1,2 @@ +# This package supports type hinting, +# see https://www.python.org/dev/peps/pep-0561/#packaging-type-information \ No newline at end of file diff --git a/func-app/graphrag/index/reporting/__init__.py b/func-app/graphrag/index/reporting/__init__.py new file mode 100644 index 0000000000..697d4fc51f --- /dev/null +++ b/func-app/graphrag/index/reporting/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Reporting utilities and implementations for the indexing engine.""" + +from .blob_workflow_callbacks import BlobWorkflowCallbacks +from .console_workflow_callbacks import ConsoleWorkflowCallbacks +from .file_workflow_callbacks import FileWorkflowCallbacks +from .load_pipeline_reporter import load_pipeline_reporter +from .progress_workflow_callbacks import ProgressWorkflowCallbacks + +__all__ = [ + "BlobWorkflowCallbacks", + "ConsoleWorkflowCallbacks", + "FileWorkflowCallbacks", + "ProgressWorkflowCallbacks", + "load_pipeline_reporter", +] diff --git a/func-app/graphrag/index/reporting/blob_workflow_callbacks.py b/func-app/graphrag/index/reporting/blob_workflow_callbacks.py new file mode 100644 index 0000000000..28f0b6d991 --- /dev/null +++ b/func-app/graphrag/index/reporting/blob_workflow_callbacks.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024 Microsoft Corporation. 
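+#
+# NOTE: Azure append blobs allow at most 50,000 committed blocks per blob;
+# _max_block_count below (25,000) rolls the reporter over to a fresh blob well
+# before that ceiling. Illustrative construction (the account URL is a
+# placeholder):
+#
+#     callbacks = BlobWorkflowCallbacks(
+#         connection_string=None,
+#         container_name="logs",
+#         storage_account_blob_url="https://<account>.blob.core.windows.net",
+#     )
+#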
+# Licensed under the MIT License + +"""A reporter that writes to a blob storage.""" + +import json +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient +from datashaper import NoopWorkflowCallbacks + + +class BlobWorkflowCallbacks(NoopWorkflowCallbacks): + """A reporter that writes to a blob storage.""" + + _blob_service_client: BlobServiceClient + _container_name: str + _max_block_count: int = 25000 # 25k blocks per blob + + def __init__( + self, + connection_string: str | None, + container_name: str, + blob_name: str = "", + base_dir: str | None = None, + storage_account_blob_url: str | None = None, + ): # type: ignore + """Create a new instance of the BlobStorageReporter class.""" + if container_name is None: + msg = "No container name provided for blob storage." + raise ValueError(msg) + if connection_string is None and storage_account_blob_url is None: + msg = "No storage account blob url provided for blob storage." + raise ValueError(msg) + self._connection_string = connection_string + self._storage_account_blob_url = storage_account_blob_url + if self._connection_string: + self._blob_service_client = BlobServiceClient.from_connection_string( + self._connection_string + ) + else: + if storage_account_blob_url is None: + msg = "Either connection_string or storage_account_blob_url must be provided." + raise ValueError(msg) + + self._blob_service_client = BlobServiceClient( + storage_account_blob_url, + credential=DefaultAzureCredential(), + ) + + if blob_name == "": + blob_name = f"report/{datetime.now(tz=timezone.utc).strftime('%Y-%m-%d-%H:%M:%S:%f')}.logs.json" + + self._blob_name = str(Path(base_dir or "") / blob_name) + self._container_name = container_name + self._blob_client = self._blob_service_client.get_blob_client( + self._container_name, self._blob_name + ) + if not self._blob_client.exists(): + self._blob_client.create_append_blob() + + self._num_blocks = 0 # refresh block counter + + def _write_log(self, log: dict[str, Any]): + # create a new file when block count hits close 25k + if ( + self._num_blocks >= self._max_block_count + ): # Check if block count exceeds 25k + self.__init__( + self._connection_string, + self._container_name, + storage_account_blob_url=self._storage_account_blob_url, + ) + + blob_client = self._blob_service_client.get_blob_client( + self._container_name, self._blob_name + ) + blob_client.append_block(json.dumps(log) + "\n") + + # update the blob's block count + self._num_blocks += 1 + + def on_error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ): + """Report an error.""" + self._write_log({ + "type": "error", + "data": message, + "cause": str(cause), + "stack": stack, + "details": details, + }) + + def on_warning(self, message: str, details: dict | None = None): + """Report a warning.""" + self._write_log({"type": "warning", "data": message, "details": details}) + + def on_log(self, message: str, details: dict | None = None): + """Report a generic log message.""" + self._write_log({"type": "log", "data": message, "details": details}) diff --git a/func-app/graphrag/index/reporting/console_workflow_callbacks.py b/func-app/graphrag/index/reporting/console_workflow_callbacks.py new file mode 100644 index 0000000000..b1ab1278f7 --- /dev/null +++ b/func-app/graphrag/index/reporting/console_workflow_callbacks.py @@ -0,0 +1,32 @@ +# 
Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Console-based reporter for the workflow engine.""" + +from datashaper import NoopWorkflowCallbacks + + +class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks): + """A reporter that writes to a console.""" + + def on_error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ): + """Handle when an error occurs.""" + print(message, str(cause), stack, details) # noqa T201 + + def on_warning(self, message: str, details: dict | None = None): + """Handle when a warning occurs.""" + _print_warning(message) + + def on_log(self, message: str, details: dict | None = None): + """Handle when a log message is produced.""" + print(message, details) # noqa T201 + + +def _print_warning(skk): + print("\033[93m {}\033[00m".format(skk)) # noqa T201 diff --git a/func-app/graphrag/index/reporting/file_workflow_callbacks.py b/func-app/graphrag/index/reporting/file_workflow_callbacks.py new file mode 100644 index 0000000000..e659c4f644 --- /dev/null +++ b/func-app/graphrag/index/reporting/file_workflow_callbacks.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A reporter that writes to a file.""" + +import json +import logging +from io import TextIOWrapper +from pathlib import Path + +from datashaper import NoopWorkflowCallbacks + +log = logging.getLogger(__name__) + + +class FileWorkflowCallbacks(NoopWorkflowCallbacks): + """A reporter that writes to a file.""" + + _out_stream: TextIOWrapper + + def __init__(self, directory: str): + """Create a new file-based workflow reporter.""" + Path(directory).mkdir(parents=True, exist_ok=True) + self._out_stream = open( # noqa: PTH123, SIM115 + Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict" + ) + + def on_error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ): + """Handle when an error occurs.""" + self._out_stream.write( + json.dumps({ + "type": "error", + "data": message, + "stack": stack, + "source": str(cause), + "details": details, + }) + + "\n" + ) + message = f"{message} details={details}" + log.info(message) + + def on_warning(self, message: str, details: dict | None = None): + """Handle when a warning occurs.""" + self._out_stream.write( + json.dumps({"type": "warning", "data": message, "details": details}) + "\n" + ) + _print_warning(message) + + def on_log(self, message: str, details: dict | None = None): + """Handle when a log message is produced.""" + self._out_stream.write( + json.dumps({"type": "log", "data": message, "details": details}) + "\n" + ) + + message = f"{message} details={details}" + log.info(message) + + +def _print_warning(skk): + log.warning(skk) diff --git a/func-app/graphrag/index/reporting/load_pipeline_reporter.py b/func-app/graphrag/index/reporting/load_pipeline_reporter.py new file mode 100644 index 0000000000..0386ea03d1 --- /dev/null +++ b/func-app/graphrag/index/reporting/load_pipeline_reporter.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024 Microsoft Corporation. 
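+#
+# NOTE: dispatch summary for the loader below -- ReportingType.file ->
+# FileWorkflowCallbacks (JSON lines under <root_dir>/<base_dir>),
+# ReportingType.console -> ConsoleWorkflowCallbacks, ReportingType.blob ->
+# BlobWorkflowCallbacks; with no config it defaults to file reporting under
+# "reports".
+#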
+# Licensed under the MIT License + +"""Load pipeline reporter method.""" + +from pathlib import Path +from typing import cast + +from datashaper import WorkflowCallbacks + +from graphrag.config import ReportingType +from graphrag.index.config import ( + PipelineBlobReportingConfig, + PipelineFileReportingConfig, + PipelineReportingConfig, +) + +from .blob_workflow_callbacks import BlobWorkflowCallbacks +from .console_workflow_callbacks import ConsoleWorkflowCallbacks +from .file_workflow_callbacks import FileWorkflowCallbacks + + +def load_pipeline_reporter( + config: PipelineReportingConfig | None, root_dir: str | None +) -> WorkflowCallbacks: + """Create a reporter for the given pipeline config.""" + config = config or PipelineFileReportingConfig(base_dir="reports") + + match config.type: + case ReportingType.file: + config = cast(PipelineFileReportingConfig, config) + return FileWorkflowCallbacks( + str(Path(root_dir or "") / (config.base_dir or "")) + ) + case ReportingType.console: + return ConsoleWorkflowCallbacks() + case ReportingType.blob: + config = cast(PipelineBlobReportingConfig, config) + return BlobWorkflowCallbacks( + config.connection_string, + config.container_name, + base_dir=config.base_dir, + storage_account_blob_url=config.storage_account_blob_url, + ) + case _: + msg = f"Unknown reporting type: {config.type}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/reporting/progress_workflow_callbacks.py b/func-app/graphrag/index/reporting/progress_workflow_callbacks.py new file mode 100644 index 0000000000..68e407e223 --- /dev/null +++ b/func-app/graphrag/index/reporting/progress_workflow_callbacks.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A workflow callback manager that emits updates to a ProgressReporter.""" + +from typing import Any + +from datashaper import ExecutionNode, NoopWorkflowCallbacks, Progress, TableContainer + +from graphrag.common.progress import ProgressReporter + + +class ProgressWorkflowCallbacks(NoopWorkflowCallbacks): + """A callbackmanager that delegates to a ProgressReporter.""" + + _root_progress: ProgressReporter + _progress_stack: list[ProgressReporter] + + def __init__(self, progress: ProgressReporter) -> None: + """Create a new ProgressWorkflowCallbacks.""" + self._progress = progress + self._progress_stack = [progress] + + def _pop(self) -> None: + self._progress_stack.pop() + + def _push(self, name: str) -> None: + self._progress_stack.append(self._latest.child(name)) + + @property + def _latest(self) -> ProgressReporter: + return self._progress_stack[-1] + + def on_workflow_start(self, name: str, instance: object) -> None: + """Execute this callback when a workflow starts.""" + self._push(name) + + def on_workflow_end(self, name: str, instance: object) -> None: + """Execute this callback when a workflow ends.""" + self._pop() + + def on_step_start(self, node: ExecutionNode, inputs: dict[str, Any]) -> None: + """Execute this callback every time a step starts.""" + verb_id_str = f" ({node.node_id})" if node.has_explicit_id else "" + self._push(f"Verb {node.verb.name}{verb_id_str}") + self._latest(Progress(percent=0)) + + def on_step_end(self, node: ExecutionNode, result: TableContainer | None) -> None: + """Execute this callback every time a step ends.""" + self._pop() + + def on_step_progress(self, node: ExecutionNode, progress: Progress) -> None: + """Handle when progress occurs.""" + self._latest(progress) diff --git a/func-app/graphrag/index/run.py 
b/func-app/graphrag/index/run.py new file mode 100644 index 0000000000..27416f7d7f --- /dev/null +++ b/func-app/graphrag/index/run.py @@ -0,0 +1,471 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Different methods to run the pipeline.""" + +import gc +import json +import logging +import time +import traceback +from collections.abc import AsyncIterable +from dataclasses import asdict +from io import BytesIO +from pathlib import Path +from string import Template +from typing import cast + +import pandas as pd +from datashaper import ( + DEFAULT_INPUT_NAME, + MemoryProfile, + Workflow, + WorkflowCallbacks, + WorkflowCallbacksManager, + WorkflowRunResult, +) +from graphrag.config.models.graphdb_config import GraphDBConfig + +from .cache import InMemoryCache, PipelineCache, load_cache +from .config import ( + PipelineBlobCacheConfig, + PipelineBlobReportingConfig, + PipelineBlobStorageConfig, + PipelineCacheConfigTypes, + PipelineConfig, + PipelineFileCacheConfig, + PipelineFileReportingConfig, + PipelineFileStorageConfig, + PipelineInputConfigTypes, + PipelineMemoryCacheConfig, + PipelineReportingConfigTypes, + PipelineStorageConfigTypes, + PipelineWorkflowReference, + PipelineWorkflowStep, +) +from .context import PipelineRunContext, PipelineRunStats +from .emit import TableEmitterType, create_table_emitters +from .input import load_input +from .load_pipeline_config import load_pipeline_config +from graphrag.common.progress import NullProgressReporter, ProgressReporter +from .reporting import ( + ConsoleWorkflowCallbacks, + ProgressWorkflowCallbacks, + load_pipeline_reporter, +) +from graphrag.common.storage import MemoryPipelineStorage, PipelineStorage, load_storage +from .typing import PipelineRunResult + +# Register all verbs +from .verbs import * # noqa +from .workflows import ( + VerbDefinitions, + WorkflowDefinitions, + create_workflow, + load_workflows, +) + +log = logging.getLogger(__name__) + + +async def run_pipeline_with_config( + config_or_path: PipelineConfig | str, + workflows: list[PipelineWorkflowReference] | None = None, + dataset: pd.DataFrame | None = None, + storage: PipelineStorage | None = None, + cache: PipelineCache | None = None, + callbacks: WorkflowCallbacks | None = None, + progress_reporter: ProgressReporter | None = None, + input_post_process_steps: list[PipelineWorkflowStep] | None = None, + additional_verbs: VerbDefinitions | None = None, + additional_workflows: WorkflowDefinitions | None = None, + emit: list[TableEmitterType] | None = None, + memory_profile: bool = False, + run_id: str | None = None, + is_resume_run: bool = False, + context_id: str | None = None, + **_kwargs: dict, +) -> AsyncIterable[PipelineRunResult]: + """Run a pipeline with the given config. + + Args: + - config_or_path - The config to run the pipeline with + - workflows - The workflows to run (this overrides the config) + - dataset - The dataset to run the pipeline on (this overrides the config) + - storage - The storage to use for the pipeline (this overrides the config) + - cache - The cache to use for the pipeline (this overrides the config) + - reporter - The reporter to use for the pipeline (this overrides the config) + - input_post_process_steps - The post process steps to run on the input data (this overrides the config) + - additional_verbs - The custom verbs to use for the pipeline. + - additional_workflows - The custom workflows to use for the pipeline. + - emit - The table emitters to use for the pipeline. 
+ - memory_profile - Whether or not to profile the memory. + - run_id - The run id to start or resume from. + """ + if isinstance(config_or_path, str): + log.info("Running pipeline with config %s", config_or_path) + else: + log.info("Running pipeline") + + run_id = run_id or time.strftime("%Y%m%d-%H%M%S") + config = load_pipeline_config(config_or_path) + config = _apply_substitutions(config, run_id) + root_dir = config.root_dir + + def _create_storage(config: PipelineStorageConfigTypes | None) -> PipelineStorage: + return load_storage( + config + or PipelineFileStorageConfig(base_dir=str(Path(root_dir or "") / "output")) + ) + + def _create_cache(config: PipelineCacheConfigTypes | None) -> PipelineCache: + return load_cache(config or PipelineMemoryCacheConfig(), root_dir=root_dir) + + def _create_reporter( + config: PipelineReportingConfigTypes | None, + ) -> WorkflowCallbacks | None: + return load_pipeline_reporter(config, root_dir) if config else None + + async def _create_input( + config: PipelineInputConfigTypes | None, + ) -> pd.DataFrame | None: + if config is None: + return None + + return await load_input(config, progress_reporter, root_dir) + + def _create_postprocess_steps( + config: PipelineInputConfigTypes | None, + ) -> list[PipelineWorkflowStep] | None: + return config.post_process if config is not None else None + + progress_reporter = progress_reporter or NullProgressReporter() + storage = storage or _create_storage(config.storage) + cache = cache or _create_cache(config.cache) + callbacks = callbacks or _create_reporter(config.reporting) + dataset = dataset if dataset is not None else await _create_input(config.input) + post_process_steps = input_post_process_steps or _create_postprocess_steps( + config.input + ) + workflows = workflows or config.workflows + + if dataset is None: + msg = "No dataset provided!" + raise ValueError(msg) + + async for table in run_pipeline( + workflows=workflows, + dataset=dataset, + storage=storage, + cache=cache, + callbacks=callbacks, + input_post_process_steps=post_process_steps, + memory_profile=memory_profile, + additional_verbs=additional_verbs, + additional_workflows=additional_workflows, + progress_reporter=progress_reporter, + emit=emit, + is_resume_run=is_resume_run, + graphdb_params = config.graphdb_params, + context_id=context_id, + ): + yield table + + +async def run_pipeline( + workflows: list[PipelineWorkflowReference], + dataset: pd.DataFrame, + storage: PipelineStorage | None = None, + cache: PipelineCache | None = None, + callbacks: WorkflowCallbacks | None = None, + progress_reporter: ProgressReporter | None = None, + input_post_process_steps: list[PipelineWorkflowStep] | None = None, + additional_verbs: VerbDefinitions | None = None, + additional_workflows: WorkflowDefinitions | None = None, + emit: list[TableEmitterType] | None = None, + memory_profile: bool = False, + is_resume_run: bool = False, + graphdb_params: GraphDBConfig|None = None, + context_id: str | None = None, + **_kwargs: dict, +) -> AsyncIterable[PipelineRunResult]: + """Run the pipeline. + + Args: + - workflows - The workflows to run + - dataset - The dataset to run the pipeline on, specifically a dataframe with the following columns at a minimum: + - id - The id of the document + - text - The text of the document + - title - The title of the document + These must exist after any post process steps are run if there are any! 
+ - storage - The storage to use for the pipeline + - cache - The cache to use for the pipeline + - reporter - The reporter to use for the pipeline + - input_post_process_steps - The post process steps to run on the input data + - additional_verbs - The custom verbs to use for the pipeline + - additional_workflows - The custom workflows to use for the pipeline + - debug - Whether or not to run in debug mode + Returns: + - output - An iterable of workflow results as they complete running, as well as any errors that occur + """ + start_time = time.time() + stats = PipelineRunStats() + storage = storage or MemoryPipelineStorage() + cache = cache or InMemoryCache() + progress_reporter = progress_reporter or NullProgressReporter() + callbacks = callbacks or ConsoleWorkflowCallbacks() + callbacks = _create_callback_chain(callbacks, progress_reporter) + emit = emit or [TableEmitterType.Parquet] + emitters = create_table_emitters( + emit, + storage, + lambda e, s, d: cast(WorkflowCallbacks, callbacks).on_error( + "Error emitting table", e, s, d + ), + graphdb_params, + context_id, + ) + loaded_workflows = load_workflows( + workflows, + additional_verbs=additional_verbs, + additional_workflows=additional_workflows, + memory_profile=memory_profile, + ) + workflows_to_run = loaded_workflows.workflows + workflow_dependencies = loaded_workflows.dependencies + + context = _create_run_context(storage, cache, stats) + + if len(emitters) == 0: + log.info( + "No emitters provided. No table outputs will be generated. This is probably not correct." + ) + + async def dump_stats() -> None: + await storage.set("stats.json", json.dumps(asdict(stats), indent=4)) + + async def load_table_from_storage(name: str) -> pd.DataFrame: + if not await storage.has(name): + msg = f"Could not find {name} in storage!" 
+ raise ValueError(msg) + try: + log.info("read table from storage: %s", name) + return pd.read_parquet(BytesIO(await storage.get(name, as_bytes=True))) + except Exception: + log.exception("error loading table from storage: %s", name) + raise + + async def inject_workflow_data_dependencies(workflow: Workflow) -> None: + workflow.add_table(DEFAULT_INPUT_NAME, dataset) + deps = workflow_dependencies[workflow.name] + log.info("dependencies for %s: %s", workflow.name, deps) + for id in deps: + workflow_id = f"workflow:{id}" + table = await load_table_from_storage(f"{id}.parquet") + workflow.add_table(workflow_id, table) + + async def write_workflow_stats( + workflow: Workflow, + workflow_result: WorkflowRunResult, + workflow_start_time: float, + ) -> None: + for vt in workflow_result.verb_timings: + stats.workflows[workflow.name][f"{vt.index}_{vt.verb}"] = vt.timing + + workflow_end_time = time.time() + stats.workflows[workflow.name]["overall"] = ( + workflow_end_time - workflow_start_time + ) + stats.total_runtime = time.time() - start_time + await dump_stats() + + if workflow_result.memory_profile is not None: + await _save_profiler_stats( + storage, workflow.name, workflow_result.memory_profile + ) + + log.debug( + "first row of %s => %s", workflow_name, workflow.output().iloc[0].to_json() + ) + + async def emit_workflow_output(workflow: Workflow) -> pd.DataFrame: + output = cast(pd.DataFrame, workflow.output()) + for emitter in emitters: + await emitter.emit(workflow.name, output) + return output + + dataset = await _run_post_process_steps( + input_post_process_steps, dataset, context, callbacks + ) + + # Make sure the incoming data is valid + _validate_dataset(dataset) + + log.info("Final # of rows loaded: %s", len(dataset)) + stats.num_documents = len(dataset) + last_workflow = "input" + + try: + await dump_stats() + + for workflow_to_run in workflows_to_run: + # Try to flush out any intermediate dataframes + gc.collect() + + workflow = workflow_to_run.workflow + workflow_name: str = workflow.name + last_workflow = workflow_name + + log.info("Running workflow: %s...", workflow_name) + + if is_resume_run and await storage.has( + f"{workflow_to_run.workflow.name}.parquet" + ): + log.info("Skipping %s because it already exists", workflow_name) + continue + + stats.workflows[workflow_name] = {"overall": 0.0} + await inject_workflow_data_dependencies(workflow) + + workflow_start_time = time.time() + result = await workflow.run(context, callbacks) + await write_workflow_stats(workflow, result, workflow_start_time) + + # Save the output from the workflow + output = await emit_workflow_output(workflow) + yield PipelineRunResult(workflow_name, output, None) + output = None + workflow.dispose() + workflow = None + + stats.total_runtime = time.time() - start_time + await dump_stats() + except Exception as e: + log.exception("error running workflow %s", last_workflow) + cast(WorkflowCallbacks, callbacks).on_error( + "Error running pipeline!", e, traceback.format_exc() + ) + yield PipelineRunResult(last_workflow, None, [e]) + + +def _create_callback_chain( + callbacks: WorkflowCallbacks | None, progress: ProgressReporter | None +) -> WorkflowCallbacks: + """Create a callbacks manager.""" + manager = WorkflowCallbacksManager() + if callbacks is not None: + manager.register(callbacks) + if progress is not None: + manager.register(ProgressWorkflowCallbacks(progress)) + return manager + + +async def _save_profiler_stats( + storage: PipelineStorage, workflow_name: str, profile: MemoryProfile +): + """Save 
the profiler stats to the storage.""" + await storage.set( + f"{workflow_name}_profiling.peak_stats.csv", + profile.peak_stats.to_csv(index=True), + ) + + await storage.set( + f"{workflow_name}_profiling.snapshot_stats.csv", + profile.snapshot_stats.to_csv(index=True), + ) + + await storage.set( + f"{workflow_name}_profiling.time_stats.csv", + profile.time_stats.to_csv(index=True), + ) + + await storage.set( + f"{workflow_name}_profiling.detailed_view.csv", + profile.detailed_view.to_csv(index=True), + ) + + +async def _run_post_process_steps( + post_process: list[PipelineWorkflowStep] | None, + dataset: pd.DataFrame, + context: PipelineRunContext, + callbacks: WorkflowCallbacks, +) -> pd.DataFrame: + """Run the pipeline. + + Args: + - post_process - The post process steps to run + - dataset - The dataset to run the steps on + - context - The pipeline run context + Returns: + - output - The dataset after running the post process steps + """ + if post_process is not None and len(post_process) > 0: + input_workflow = create_workflow( + "Input Post Process", + post_process, + ) + input_workflow.add_table(DEFAULT_INPUT_NAME, dataset) + await input_workflow.run( + context=context, + callbacks=callbacks, + ) + dataset = cast(pd.DataFrame, input_workflow.output()) + return dataset + + +def _validate_dataset(dataset: pd.DataFrame): + """Validate the dataset for the pipeline. + + Args: + - dataset - The dataset to validate + """ + if not isinstance(dataset, pd.DataFrame): + msg = "Dataset must be a pandas dataframe!" + raise TypeError(msg) + + +def _apply_substitutions(config: PipelineConfig, run_id: str) -> PipelineConfig: + substitutions = {"timestamp": run_id} + + if ( + isinstance( + config.storage, PipelineFileStorageConfig | PipelineBlobStorageConfig + ) + and config.storage.base_dir + ): + config.storage.base_dir = Template(config.storage.base_dir).substitute( + substitutions + ) + if ( + isinstance(config.cache, PipelineFileCacheConfig | PipelineBlobCacheConfig) + and config.cache.base_dir + ): + config.cache.base_dir = Template(config.cache.base_dir).substitute( + substitutions + ) + + if ( + isinstance( + config.reporting, PipelineFileReportingConfig | PipelineBlobReportingConfig + ) + and config.reporting.base_dir + ): + config.reporting.base_dir = Template(config.reporting.base_dir).substitute( + substitutions + ) + + return config + + +def _create_run_context( + storage: PipelineStorage, + cache: PipelineCache, + stats: PipelineRunStats, +) -> PipelineRunContext: + """Create the run context for the pipeline.""" + return PipelineRunContext( + stats=stats, + cache=cache, + storage=storage, + ) diff --git a/func-app/graphrag/index/text_splitting/__init__.py b/func-app/graphrag/index/text_splitting/__init__.py new file mode 100644 index 0000000000..4653adb22b --- /dev/null +++ b/func-app/graphrag/index/text_splitting/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Microsoft Corporation. 
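+#
+# Usage sketch for the token splitter exported here (chunk sizes are
+# illustrative; the encoding defaults to cl100k_base):
+#
+#     from graphrag.index.text_splitting import TokenTextSplitter
+#
+#     splitter = TokenTextSplitter(chunk_size=1200, chunk_overlap=100)
+#     chunks = splitter.split_text(long_text)  # list of ~1200-token strings
+#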
+# Licensed under the MIT License + +"""The Indexing Engine Text Splitting package root.""" + +from .check_token_limit import check_token_limit +from .text_splitting import ( + DecodeFn, + EncodedText, + EncodeFn, + LengthFn, + NoopTextSplitter, + TextListSplitter, + TextListSplitterType, + TextSplitter, + Tokenizer, + TokenTextSplitter, + split_text_on_tokens, +) + +__all__ = [ + "DecodeFn", + "EncodeFn", + "EncodedText", + "LengthFn", + "NoopTextSplitter", + "TextListSplitter", + "TextListSplitterType", + "TextSplitter", + "TokenTextSplitter", + "Tokenizer", + "check_token_limit", + "split_text_on_tokens", +] diff --git a/func-app/graphrag/index/text_splitting/check_token_limit.py b/func-app/graphrag/index/text_splitting/check_token_limit.py new file mode 100644 index 0000000000..1a5f862254 --- /dev/null +++ b/func-app/graphrag/index/text_splitting/check_token_limit.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Token limit method definition.""" + +from .text_splitting import TokenTextSplitter + + +def check_token_limit(text, max_token): + """Check token limit.""" + text_splitter = TokenTextSplitter(chunk_size=max_token, chunk_overlap=0) + docs = text_splitter.split_text(text) + if len(docs) > 1: + return 0 + return 1 diff --git a/func-app/graphrag/index/text_splitting/text_splitting.py b/func-app/graphrag/index/text_splitting/text_splitting.py new file mode 100644 index 0000000000..0badc8977c --- /dev/null +++ b/func-app/graphrag/index/text_splitting/text_splitting.py @@ -0,0 +1,244 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing the 'Tokenizer', 'TextSplitter', 'NoopTextSplitter' and 'TokenTextSplitter' models.""" + +import json +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable, Collection, Iterable +from dataclasses import dataclass +from enum import Enum +from typing import Any, Literal, cast + +import pandas as pd +import tiktoken + +from graphrag.index.utils import num_tokens_from_string + +EncodedText = list[int] +DecodeFn = Callable[[EncodedText], str] +EncodeFn = Callable[[str], EncodedText] +LengthFn = Callable[[str], int] + +log = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class Tokenizer: + """Tokenizer data class.""" + + chunk_overlap: int + """Overlap in tokens between chunks""" + tokens_per_chunk: int + """Maximum number of tokens per chunk""" + decode: DecodeFn + """ Function to decode a list of token ids to a string""" + encode: EncodeFn + """ Function to encode a string to a list of token ids""" + + +class TextSplitter(ABC): + """Text splitter class definition.""" + + _chunk_size: int + _chunk_overlap: int + _length_function: LengthFn + _keep_separator: bool + _add_start_index: bool + _strip_whitespace: bool + + def __init__( + self, + # based on text-ada-002-embedding max input buffer length + # https://platform.openai.com/docs/guides/embeddings/second-generation-models + chunk_size: int = 8191, + chunk_overlap: int = 100, + length_function: LengthFn = len, + keep_separator: bool = False, + add_start_index: bool = False, + strip_whitespace: bool = True, + ): + """Init method definition.""" + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._length_function = length_function + self._keep_separator = keep_separator + self._add_start_index = add_start_index + self._strip_whitespace = strip_whitespace + + @abstractmethod + def split_text(self, text: str | list[str]) -> 
Iterable[str]: + """Split text method definition.""" + + +class NoopTextSplitter(TextSplitter): + """Noop text splitter class definition.""" + + def split_text(self, text: str | list[str]) -> Iterable[str]: + """Split text method definition.""" + return [text] if isinstance(text, str) else text + + +class TokenTextSplitter(TextSplitter): + """Token text splitter class definition.""" + + _allowed_special: Literal["all"] | set[str] + _disallowed_special: Literal["all"] | Collection[str] + + def __init__( + self, + encoding_name: str = "cl100k_base", + model_name: str | None = None, + allowed_special: Literal["all"] | set[str] | None = None, + disallowed_special: Literal["all"] | Collection[str] = "all", + **kwargs: Any, + ): + """Init method definition.""" + super().__init__(**kwargs) + if model_name is not None: + try: + enc = tiktoken.encoding_for_model(model_name) + except KeyError: + log.exception("Model %s not found, using %s", model_name, encoding_name) + enc = tiktoken.get_encoding(encoding_name) + else: + enc = tiktoken.get_encoding(encoding_name) + self._tokenizer = enc + self._allowed_special = allowed_special or set() + self._disallowed_special = disallowed_special + + def encode(self, text: str) -> list[int]: + """Encode the given text into an int-vector.""" + return self._tokenizer.encode( + text, + allowed_special=self._allowed_special, + disallowed_special=self._disallowed_special, + ) + + def num_tokens(self, text: str) -> int: + """Return the number of tokens in a string.""" + return len(self.encode(text)) + + def split_text(self, text: str | list[str]) -> list[str]: + """Split text method.""" + if cast(bool, pd.isna(text)) or text == "": + return [] + if isinstance(text, list): + text = " ".join(text) + if not isinstance(text, str): + msg = f"Attempting to split a non-string value, actual is {type(text)}" + raise TypeError(msg) + + tokenizer = Tokenizer( + chunk_overlap=self._chunk_overlap, + tokens_per_chunk=self._chunk_size, + decode=self._tokenizer.decode, + encode=lambda text: self.encode(text), + ) + + return split_text_on_tokens(text=text, tokenizer=tokenizer) + + +class TextListSplitterType(str, Enum): + """Enum for the type of the TextListSplitter.""" + + DELIMITED_STRING = "delimited_string" + JSON = "json" + + +class TextListSplitter(TextSplitter): + """Text list splitter class definition.""" + + def __init__( + self, + chunk_size: int, + splitter_type: TextListSplitterType = TextListSplitterType.JSON, + input_delimiter: str | None = None, + output_delimiter: str | None = None, + model_name: str | None = None, + encoding_name: str | None = None, + ): + """Initialize the TextListSplitter with a chunk size.""" + # Set the chunk overlap to 0 as we use full strings + super().__init__(chunk_size, chunk_overlap=0) + self._type = splitter_type + self._input_delimiter = input_delimiter + self._output_delimiter = output_delimiter or "\n" + self._length_function = lambda x: num_tokens_from_string( + x, model=model_name, encoding_name=encoding_name + ) + + def split_text(self, text: str | list[str]) -> Iterable[str]: + """Split a string list into a list of strings for a given chunk size.""" + if not text: + return [] + + result: list[str] = [] + current_chunk: list[str] = [] + + # Add the brackets + current_length: int = self._length_function("[]") + + # Input should be a string list joined by a delimiter + string_list = self._load_text_list(text) + + if len(string_list) == 1: + return string_list + + for item in string_list: + # Count the length of the item and add comma + 
item_length = self._length_function(f"{item},")
+
+            if current_length + item_length > self._chunk_size:
+                if current_chunk and len(current_chunk) > 0:
+                    # Add the current chunk to the result
+                    self._append_to_result(result, current_chunk)
+
+                    # Start a new chunk
+                    current_chunk = [item]
+                    # Reset the running length to the new item
+                    # (item_length already includes its trailing comma)
+                    current_length = item_length
+            else:
+                # Add the item to the current chunk
+                current_chunk.append(item)
+                # item_length already accounts for the separating comma
+                current_length += item_length
+
+        # Add the last chunk to the result
+        self._append_to_result(result, current_chunk)
+
+        return result
+
+    def _load_text_list(self, text: str | list[str]):
+        """Load the text list based on the type."""
+        if isinstance(text, list):
+            string_list = text
+        elif self._type == TextListSplitterType.JSON:
+            string_list = json.loads(text)
+        else:
+            string_list = text.split(self._input_delimiter)
+        return string_list
+
+    def _append_to_result(self, chunk_list: list[str], new_chunk: list[str]):
+        """Append the current chunk to the result."""
+        if new_chunk and len(new_chunk) > 0:
+            if self._type == TextListSplitterType.JSON:
+                chunk_list.append(json.dumps(new_chunk))
+            else:
+                chunk_list.append(self._output_delimiter.join(new_chunk))
+
+
+def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
+    """Split incoming text and return chunks using tokenizer."""
+    splits: list[str] = []
+    input_ids = tokenizer.encode(text)
+    start_idx = 0
+    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+    chunk_ids = input_ids[start_idx:cur_idx]
+    while start_idx < len(input_ids):
+        splits.append(tokenizer.decode(chunk_ids))
+        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
+        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+        chunk_ids = input_ids[start_idx:cur_idx]
+    return splits
diff --git a/func-app/graphrag/index/typing.py b/func-app/graphrag/index/typing.py
new file mode 100644
index 0000000000..ed1d7e93e7
--- /dev/null
+++ b/func-app/graphrag/index/typing.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the 'PipelineRunResult' model."""
+
+from collections.abc import Callable
+from dataclasses import dataclass
+
+import pandas as pd
+
+ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
+
+
+@dataclass
+class PipelineRunResult:
+    """Pipeline run result class definition."""
+
+    workflow: str
+    result: pd.DataFrame | None
+    errors: list[BaseException] | None
diff --git a/func-app/graphrag/index/utils/__init__.py b/func-app/graphrag/index/utils/__init__.py
new file mode 100644
index 0000000000..7cbbb53d75
--- /dev/null
+++ b/func-app/graphrag/index/utils/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024 Microsoft Corporation.
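+#
+# Illustrative uses of two helpers re-exported below:
+#
+#     from graphrag.index.utils import gen_md5_hash, num_tokens_from_string
+#
+#     doc_id = gen_md5_hash({"text": "hello"}, ["text"])   # stable content id
+#     n = num_tokens_from_string("hello", encoding_name="cl100k_base")
+#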
+# Licensed under the MIT License + +"""Utils methods definition.""" + +from .dicts import dict_has_keys_with_types +from .hashing import gen_md5_hash +from .is_null import is_null +from .load_graph import load_graph +from .string import clean_str +from .tokens import num_tokens_from_string, string_from_tokens +from .topological_sort import topological_sort +from .uuid import gen_uuid + +__all__ = [ + "clean_str", + "dict_has_keys_with_types", + "gen_md5_hash", + "gen_uuid", + "is_null", + "load_graph", + "num_tokens_from_string", + "string_from_tokens", + "topological_sort", +] diff --git a/func-app/graphrag/index/utils/dataframes.py b/func-app/graphrag/index/utils/dataframes.py new file mode 100644 index 0000000000..ea65d71d7a --- /dev/null +++ b/func-app/graphrag/index/utils/dataframes.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing DataFrame utilities.""" + +from collections.abc import Callable +from typing import Any, cast + +import pandas as pd +from pandas._typing import MergeHow + + +def drop_columns(df: pd.DataFrame, *column: str) -> pd.DataFrame: + """Drop columns from a dataframe.""" + return df.drop(list(column), axis=1) + + +def where_column_equals(df: pd.DataFrame, column: str, value: Any) -> pd.DataFrame: + """Return a filtered DataFrame where a column equals a value.""" + return cast(pd.DataFrame, df[df[column] == value]) + + +def antijoin(df: pd.DataFrame, exclude: pd.DataFrame, column: str) -> pd.DataFrame: + """Return an anti-joined dataframe. + + Arguments: + * df: The DataFrame to apply the exclusion to + * exclude: The DataFrame containing rows to remove. + * column: The join-on column. + """ + result = df.merge( + exclude[[column]], + on=column, + how="outer", + indicator=True, + ) + if "_merge" in result.columns: + result = result[result["_merge"] == "left_only"].drop("_merge", axis=1) + return cast(pd.DataFrame, result) + + +def transform_series(series: pd.Series, fn: Callable[[Any], Any]) -> pd.Series: + """Apply a transformation function to a series.""" + return cast(pd.Series, series.apply(fn)) + + +def join( + left: pd.DataFrame, right: pd.DataFrame, key: str, strategy: MergeHow = "left" +) -> pd.DataFrame: + """Perform a table join.""" + return left.merge(right, on=key, how=strategy) + + +def union(*frames: pd.DataFrame) -> pd.DataFrame: + """Perform a union operation on the given set of dataframes.""" + return pd.concat(list(frames)) + + +def select(df: pd.DataFrame, *columns: str) -> pd.DataFrame: + """Select columns from a dataframe.""" + return cast(pd.DataFrame, df[list(columns)]) diff --git a/func-app/graphrag/index/utils/dicts.py b/func-app/graphrag/index/utils/dicts.py new file mode 100644 index 0000000000..4d3662e0b8 --- /dev/null +++ b/func-app/graphrag/index/utils/dicts.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. 
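+#
+# Illustrative checks with the helper defined below:
+#
+#     dict_has_keys_with_types({"a": 1, "b": "x"}, [("a", int), ("b", str)])  # True
+#     dict_has_keys_with_types({"a": 1}, [("a", str)])                        # False
+#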
+# Licensed under the MIT License
+
+"""A utility module containing methods for inspecting and verifying dictionary types."""
+
+
+def dict_has_keys_with_types(
+    data: dict, expected_fields: list[tuple[str, type]]
+) -> bool:
+    """Return True if the given dictionary has the given keys with the given types."""
+    for field, field_type in expected_fields:
+        if field not in data:
+            return False
+
+        value = data[field]
+        if not isinstance(value, field_type):
+            return False
+    return True
diff --git a/func-app/graphrag/index/utils/ds_util.py b/func-app/graphrag/index/utils/ds_util.py
new file mode 100644
index 0000000000..d65c69e4a8
--- /dev/null
+++ b/func-app/graphrag/index/utils/ds_util.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A utility module containing datashaper-specific helper methods."""
+
+from typing import cast
+
+from datashaper import TableContainer, VerbInput
+
+_NAMED_INPUTS_REQUIRED = "Named inputs are required"
+
+
+def get_required_input_table(input: VerbInput, name: str) -> TableContainer:
+    """Get a required input table by name."""
+    return cast(TableContainer, get_named_input_table(input, name, required=True))
+
+
+def get_named_input_table(
+    input: VerbInput, name: str, required: bool = False
+) -> TableContainer | None:
+    """Get an input table from datashaper verb-inputs by name."""
+    named_inputs = input.named
+    if named_inputs is None:
+        if not required:
+            return None
+        raise ValueError(_NAMED_INPUTS_REQUIRED)
+
+    result = named_inputs.get(name)
+    if result is None and required:
+        msg = f"input '{name}' is required"
+        raise ValueError(msg)
+    return result
diff --git a/func-app/graphrag/index/utils/hashing.py b/func-app/graphrag/index/utils/hashing.py
new file mode 100644
index 0000000000..342ae99d44
--- /dev/null
+++ b/func-app/graphrag/index/utils/hashing.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Hashing utilities."""
+
+from collections.abc import Iterable
+from hashlib import md5
+from typing import Any
+
+
+def gen_md5_hash(item: dict[str, Any], hashcode: Iterable[str]):
+    """Generate an md5 hash."""
+    hashed = "".join([str(item[column]) for column in hashcode])
+    return md5(hashed.encode("utf-8"), usedforsecurity=False).hexdigest()
diff --git a/func-app/graphrag/index/utils/is_null.py b/func-app/graphrag/index/utils/is_null.py
new file mode 100644
index 0000000000..f5df1955ae
--- /dev/null
+++ b/func-app/graphrag/index/utils/is_null.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Defines the is_null utility."""
+
+import math
+from typing import Any
+
+
+def is_null(value: Any) -> bool:
+    """Check if value is null or is nan."""
+
+    def is_none() -> bool:
+        return value is None
+
+    def is_nan() -> bool:
+        return isinstance(value, float) and math.isnan(value)
+
+    return is_none() or is_nan()
diff --git a/func-app/graphrag/index/utils/load_graph.py b/func-app/graphrag/index/utils/load_graph.py
new file mode 100644
index 0000000000..57992a04c8
--- /dev/null
+++ b/func-app/graphrag/index/utils/load_graph.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2024 Microsoft Corporation.
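+#
+# Illustrative round-trip with the helper defined below:
+#
+#     import networkx as nx
+#
+#     g1 = nx.Graph()
+#     g1.add_edge("a", "b")
+#     g2 = load_graph("\n".join(nx.generate_graphml(g1)))  # parse GraphML text
+#     g3 = load_graph(g1)                                  # returned unchanged
+#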
+# Licensed under the MIT License + +"""Networkx load_graph utility definition.""" + +import networkx as nx + + +def load_graph(graphml: str | nx.Graph) -> nx.Graph: + """Load a graph from a graphml file or a networkx graph.""" + return nx.parse_graphml(graphml) if isinstance(graphml, str) else graphml diff --git a/func-app/graphrag/index/utils/rate_limiter.py b/func-app/graphrag/index/utils/rate_limiter.py new file mode 100644 index 0000000000..8dc641719b --- /dev/null +++ b/func-app/graphrag/index/utils/rate_limiter.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Rate limiter utility.""" + +import asyncio +import time + + +class RateLimiter: + """ + The original TpmRpmLLMLimiter strategy did not account for minute-based rate limiting when scheduled. + + The RateLimiter was introduced to ensure that the CommunityReportsExtractor could be scheduled to adhere to rate configurations on a per-minute basis. + """ + + # TODO: RateLimiter scheduled: using asyncio for async_mode + + def __init__(self, rate: int, per: int): + self.rate = rate + self.per = per + self.allowance = rate + self.last_check = time.monotonic() + + async def acquire(self): + """Acquire a token from the rate limiter.""" + current = time.monotonic() + elapsed = current - self.last_check + self.last_check = current + self.allowance += elapsed * (self.rate / self.per) + + if self.allowance > self.rate: + self.allowance = self.rate + + if self.allowance < 1.0: + sleep_time = (1.0 - self.allowance) * (self.per / self.rate) + await asyncio.sleep(sleep_time) + self.allowance = 0.0 + else: + self.allowance -= 1.0 diff --git a/func-app/graphrag/index/utils/string.py b/func-app/graphrag/index/utils/string.py new file mode 100644 index 0000000000..7e1654bb4e --- /dev/null +++ b/func-app/graphrag/index/utils/string.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""String utilities.""" + +import html +import re +from typing import Any + + +def clean_str(input: Any) -> str: + """Clean an input string by removing HTML escapes, control characters, and other unwanted characters.""" + # If we get non-string input, just give it back + if not isinstance(input, str): + return input + + result = html.unescape(input.strip()) + # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python + return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) diff --git a/func-app/graphrag/index/utils/tokens.py b/func-app/graphrag/index/utils/tokens.py new file mode 100644 index 0000000000..4a189b9b22 --- /dev/null +++ b/func-app/graphrag/index/utils/tokens.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Utilities for working with tokens.""" + +import logging + +import tiktoken + +DEFAULT_ENCODING_NAME = "cl100k_base" +log = logging.getLogger(__name__) + + +def num_tokens_from_string( + string: str, model: str | None = None, encoding_name: str | None = None +) -> int: + """Return the number of tokens in a text string.""" + if model is not None: + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. 
Fall back to default encoding {DEFAULT_ENCODING_NAME}" + log.warning(msg) + encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME) + else: + encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME) + return len(encoding.encode(string)) + + +def string_from_tokens( + tokens: list[int], model: str | None = None, encoding_name: str | None = None +) -> str: + """Return a text string from a list of tokens.""" + if model is not None: + encoding = tiktoken.encoding_for_model(model) + elif encoding_name is not None: + encoding = tiktoken.get_encoding(encoding_name) + else: + msg = "Either model or encoding_name must be specified." + raise ValueError(msg) + return encoding.decode(tokens) diff --git a/func-app/graphrag/index/utils/topological_sort.py b/func-app/graphrag/index/utils/topological_sort.py new file mode 100644 index 0000000000..a19b464559 --- /dev/null +++ b/func-app/graphrag/index/utils/topological_sort.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Topological sort utility method.""" + +from graphlib import TopologicalSorter + + +def topological_sort(graph: dict[str, list[str]]) -> list[str]: + """Topological sort.""" + ts = TopologicalSorter(graph) + return list(ts.static_order()) diff --git a/func-app/graphrag/index/utils/uuid.py b/func-app/graphrag/index/utils/uuid.py new file mode 100644 index 0000000000..0671fb09da --- /dev/null +++ b/func-app/graphrag/index/utils/uuid.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""UUID utilities.""" + +import uuid +from random import Random, getrandbits + + +def gen_uuid(rd: Random | None = None): + """Generate a random UUID v4.""" + return uuid.UUID( + int=rd.getrandbits(128) if rd is not None else getrandbits(128), version=4 + ).hex diff --git a/func-app/graphrag/index/verbs/__init__.py b/func-app/graphrag/index/verbs/__init__.py new file mode 100644 index 0000000000..379c2a3749 --- /dev/null +++ b/func-app/graphrag/index/verbs/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing get_default_verbs method definition.""" + +from .covariates import extract_covariates +from .entities import entity_extract, summarize_descriptions +from .genid import genid +from .graph import ( + cluster_graph, + create_community_reports, + create_graph, + embed_graph, + layout_graph, + merge_graphs, + unpack_graph, +) +from .overrides import aggregate, concat, merge +from .snapshot import snapshot +from .snapshot_rows import snapshot_rows +from .spread_json import spread_json +from .text import chunk, text_embed, text_split, text_translate +from .unzip import unzip +from .zip import zip_verb + +__all__ = [ + "aggregate", + "chunk", + "cluster_graph", + "concat", + "create_community_reports", + "create_graph", + "embed_graph", + "entity_extract", + "extract_covariates", + "genid", + "layout_graph", + "merge", + "merge_graphs", + "snapshot", + "snapshot_rows", + "spread_json", + "summarize_descriptions", + "text_embed", + "text_split", + "text_translate", + "unpack_graph", + "unzip", + "zip_verb", +] diff --git a/func-app/graphrag/index/verbs/covariates/__init__.py b/func-app/graphrag/index/verbs/covariates/__init__.py new file mode 100644 index 0000000000..cdebee228b --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. 
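
Stepping back to `rate_limiter.py` above: `acquire()` refills the allowance proportionally to elapsed time and sleeps when the bucket is empty, which is what lets per-minute rate configs hold. A minimal usage sketch, assuming the module import path from this patch:

```python
import asyncio

from graphrag.index.utils.rate_limiter import RateLimiter


async def main() -> None:
    # Allow roughly 5 acquisitions per second; each acquire() either spends
    # one unit of allowance or sleeps until enough allowance has refilled.
    limiter = RateLimiter(rate=5, per=1)
    for i in range(12):
        await limiter.acquire()
        print("call", i)


asyncio.run(main())
```
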
+# Licensed under the MIT License + +"""The Indexing Engine covariates package root.""" + +from .extract_covariates import extract_covariates + +__all__ = ["extract_covariates"] diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/__init__.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/__init__.py new file mode 100644 index 0000000000..53d357bb46 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine text extract claims package root.""" + +from .extract_covariates import ExtractClaimsStrategyType, extract_covariates + +__all__ = ["ExtractClaimsStrategyType", "extract_covariates"] diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/extract_covariates.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/extract_covariates.py new file mode 100644 index 0000000000..a67cb0fa0e --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/extract_covariates.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing the extract_covariates verb definition.""" + +import logging +from dataclasses import asdict +from enum import Enum +from typing import Any, cast + +import pandas as pd +from datashaper import ( + AsyncType, + TableContainer, + VerbCallbacks, + VerbInput, + derive_from_rows, + verb, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.verbs.covariates.typing import Covariate, CovariateExtractStrategy + +log = logging.getLogger(__name__) + + +class ExtractClaimsStrategyType(str, Enum): + """ExtractClaimsStrategyType class definition.""" + + graph_intelligence = "graph_intelligence" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"] + + +@verb(name="extract_covariates") +async def extract_covariates( + input: VerbInput, + cache: PipelineCache, + callbacks: VerbCallbacks, + column: str, + covariate_type: str, + strategy: dict[str, Any] | None, + async_mode: AsyncType = AsyncType.AsyncIO, + entity_types: list[str] | None = None, + **kwargs, +) -> TableContainer: + """ + Extract claims from a piece of text. 
+ + ## Usage + TODO + """ + log.debug("extract_covariates strategy=%s", strategy) + if entity_types is None: + entity_types = DEFAULT_ENTITY_TYPES + output = cast(pd.DataFrame, input.get_input()) + + resolved_entities_map = {} + + strategy = strategy or {} + strategy_exec = load_strategy( + strategy.get("type", ExtractClaimsStrategyType.graph_intelligence) + ) + strategy_config = {**strategy} + + async def run_strategy(row): + text = row[column] + result = await strategy_exec( + text, entity_types, resolved_entities_map, callbacks, cache, strategy_config + ) + return [ + create_row_from_claim_data(row, item, covariate_type) + for item in result.covariate_data + ] + + results = await derive_from_rows( + output, + run_strategy, + callbacks, + scheduling_type=async_mode, + num_threads=kwargs.get("num_threads", 4), + ) + output = pd.DataFrame([item for row in results for item in row or []]) + return TableContainer(table=output) + + +def load_strategy(strategy_type: ExtractClaimsStrategyType) -> CovariateExtractStrategy: + """Load strategy method definition.""" + match strategy_type: + case ExtractClaimsStrategyType.graph_intelligence: + from .strategies.graph_intelligence import run as run_gi + + return run_gi + case _: + msg = f"Unknown strategy: {strategy_type}" + raise ValueError(msg) + + +def create_row_from_claim_data(row, covariate_data: Covariate, covariate_type: str): + """Create a row from the claim data and the input row.""" + item = {**row, **asdict(covariate_data), "covariate_type": covariate_type} + # TODO: doc_id from extraction isn't necessary + # since chunking happens before this + del item["doc_id"] + return item diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/__init__.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/__init__.py new file mode 100644 index 0000000000..605c66f8d1 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine text extract claims strategies package root.""" diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/__init__.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/__init__.py new file mode 100644 index 0000000000..ab01f06fc4 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine text extract claims strategies graph intelligence package root.""" + +from .run_gi_extract_claims import run + +__all__ = ["run"] diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/defaults.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/defaults.py new file mode 100644 index 0000000000..846bfa81e0 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/defaults.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Microsoft Corporation. 
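
The row shaping done by `create_row_from_claim_data` above is easy to miss: the input row and the claim's fields are merged, a `covariate_type` marker is added, and `doc_id` is dropped because chunking happened upstream. A hedged sketch with an illustrative row and claim (`Covariate` comes from `graphrag.index.verbs.covariates.typing` per this patch):

```python
from dataclasses import asdict

from graphrag.index.verbs.covariates.typing import Covariate

row = {"chunk": "some chunk text", "doc_id": "d1"}
claim = Covariate(subject_id="COMPANY A", type="ANTI-COMPETITIVE PRACTICES", doc_id="d1")

# Same merge-then-delete shape as create_row_from_claim_data:
item = {**row, **asdict(claim), "covariate_type": "claim"}
del item["doc_id"]  # dropped: chunking already happened before this verb
print(item["subject_id"], item["covariate_type"])  # COMPANY A claim
```
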
+# Licensed under the MIT License + +"""A file containing MOCK_LLM_RESPONSES definition.""" + +MOCK_LLM_RESPONSES = [ + """ +[ + { + "subject": "COMPANY A", + "object": "GOVERNMENT AGENCY B", + "type": "ANTI-COMPETITIVE PRACTICES", + "status": "TRUE", + "start_date": "2022-01-10T00:00:00", + "end_date": "2022-01-10T00:00:00", + "description": "Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10", + "source_text": ["According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B."] + } +] + """.strip() +] diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/run_gi_extract_claims.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/run_gi_extract_claims.py new file mode 100644 index 0000000000..1c9f058830 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/run_gi_extract_claims.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run and _run_chain methods definitions.""" + +from collections.abc import Iterable +from typing import Any + +from datashaper import VerbCallbacks + +import graphrag.config.defaults as defs +from graphrag.config.enums import LLMType +from graphrag.index.cache import PipelineCache +from graphrag.index.graph.extractors.claims import ClaimExtractor +from graphrag.index.llm import load_llm +from graphrag.index.verbs.covariates.typing import ( + Covariate, + CovariateExtractionResult, +) +from graphrag.llm import CompletionLLM + +from .defaults import MOCK_LLM_RESPONSES + + +async def run( + input: str | Iterable[str], + entity_types: list[str], + resolved_entities_map: dict[str, str], + reporter: VerbCallbacks, + pipeline_cache: PipelineCache, + strategy_config: dict[str, Any], +) -> CovariateExtractionResult: + """Run the Claim extraction chain.""" + llm_config = strategy_config.get( + "llm", {"type": LLMType.StaticResponse, "responses": MOCK_LLM_RESPONSES} + ) + llm_type = llm_config.get("type", LLMType.StaticResponse) + llm = load_llm("claim_extraction", llm_type, reporter, pipeline_cache, llm_config) + return await _execute( + llm, input, entity_types, resolved_entities_map, reporter, strategy_config + ) + + +async def _execute( + llm: CompletionLLM, + texts: Iterable[str], + entity_types: list[str], + resolved_entities_map: dict[str, str], + reporter: VerbCallbacks, + strategy_config: dict[str, Any], +) -> CovariateExtractionResult: + extraction_prompt = strategy_config.get("extraction_prompt") + max_gleanings = strategy_config.get("max_gleanings", defs.CLAIM_MAX_GLEANINGS) + tuple_delimiter = strategy_config.get("tuple_delimiter") + record_delimiter = strategy_config.get("record_delimiter") + completion_delimiter = strategy_config.get("completion_delimiter") + encoding_model = strategy_config.get("encoding_name") + + extractor = ClaimExtractor( + llm_invoker=llm, + extraction_prompt=extraction_prompt, + max_gleanings=max_gleanings, + encoding_model=encoding_model, + on_error=lambda e, s, d: ( + reporter.error("Claim Extraction Error", e, s, d) if reporter else None + ), + ) + + claim_description = strategy_config.get("claim_description") + if claim_description is None: + msg = "claim_description 
is required for claim extraction" + raise ValueError(msg) + + texts = [texts] if isinstance(texts, str) else texts + + results = await extractor({ + "input_text": texts, + "entity_specs": entity_types, + "resolved_entities": resolved_entities_map, + "claim_description": claim_description, + "tuple_delimiter": tuple_delimiter, + "record_delimiter": record_delimiter, + "completion_delimiter": completion_delimiter, + }) + + claim_data = results.output + return CovariateExtractionResult([create_covariate(item) for item in claim_data]) + + +def create_covariate(item: dict[str, Any]) -> Covariate: + """Create a covariate from the item.""" + return Covariate( + subject_id=item.get("subject_id"), + subject_type=item.get("subject_type"), + object_id=item.get("object_id"), + object_type=item.get("object_type"), + type=item.get("type"), + status=item.get("status"), + start_date=item.get("start_date"), + end_date=item.get("end_date"), + description=item.get("description"), + source_text=item.get("source_text"), + doc_id=item.get("doc_id"), + record_id=item.get("record_id"), + id=item.get("id"), + ) diff --git a/func-app/graphrag/index/verbs/covariates/typing.py b/func-app/graphrag/index/verbs/covariates/typing.py new file mode 100644 index 0000000000..e31cfa4989 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/typing.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'Covariate' and 'CovariateExtractionResult' models.""" + +from collections.abc import Awaitable, Callable, Iterable +from dataclasses import dataclass +from typing import Any + +from datashaper import VerbCallbacks + +from graphrag.index.cache import PipelineCache + + +@dataclass +class Covariate: + """Covariate class definition.""" + + covariate_type: str | None = None + subject_id: str | None = None + subject_type: str | None = None + object_id: str | None = None + object_type: str | None = None + type: str | None = None + status: str | None = None + start_date: str | None = None + end_date: str | None = None + description: str | None = None + source_text: list[str] | None = None + doc_id: str | None = None + record_id: int | None = None + id: str | None = None + + +@dataclass +class CovariateExtractionResult: + """Covariate extraction result class definition.""" + + covariate_data: list[Covariate] + + +CovariateExtractStrategy = Callable[ + [ + Iterable[str], + list[str], + dict[str, str], + VerbCallbacks, + PipelineCache, + dict[str, Any], + ], + Awaitable[CovariateExtractionResult], +] diff --git a/func-app/graphrag/index/verbs/entities/__init__.py b/func-app/graphrag/index/verbs/entities/__init__.py new file mode 100644 index 0000000000..2f55d710e9 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine entities package root.""" + +from .extraction import entity_extract +from .summarize import summarize_descriptions + +__all__ = ["entity_extract", "summarize_descriptions"] diff --git a/func-app/graphrag/index/verbs/entities/extraction/__init__.py b/func-app/graphrag/index/verbs/entities/extraction/__init__.py new file mode 100644 index 0000000000..46e6d54581 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. 
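
Since `_execute` above reads every knob via `strategy_config.get(...)` and hard-fails only on a missing `claim_description`, a minimal config sketch looks like this (all values are illustrative assumptions, not shipped defaults):

```python
# Keys mirror the .get(...) lookups in _execute; values are illustrative.
strategy_config = {
    "claim_description": "Claims relevant to compliance and fraud review.",
    "max_gleanings": 1,                      # else defs.CLAIM_MAX_GLEANINGS
    "tuple_delimiter": "<|>",
    "record_delimiter": "##",
    "completion_delimiter": "<|COMPLETE|>",
    # omit "llm" to fall back to the mock StaticResponse config in run()
}
```
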
+# Licensed under the MIT License + +"""The Indexing Engine entities extraction package root.""" + +from .entity_extract import ExtractEntityStrategyType, entity_extract + +__all__ = ["ExtractEntityStrategyType", "entity_extract"] diff --git a/func-app/graphrag/index/verbs/entities/extraction/entity_extract.py b/func-app/graphrag/index/verbs/entities/extraction/entity_extract.py new file mode 100644 index 0000000000..4e961f674d --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/entity_extract.py @@ -0,0 +1,202 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing entity_extract methods.""" + +import logging +from enum import Enum +from typing import Any, cast + +import pandas as pd +from datashaper import ( + AsyncType, + TableContainer, + VerbCallbacks, + VerbInput, + derive_from_rows, + verb, +) + +from graphrag.index.bootstrap import bootstrap +from graphrag.index.cache import PipelineCache + +from .strategies.typing import Document, EntityExtractStrategy + +log = logging.getLogger(__name__) + + +class ExtractEntityStrategyType(str, Enum): + """ExtractEntityStrategyType class definition.""" + + graph_intelligence = "graph_intelligence" + graph_intelligence_json = "graph_intelligence_json" + nltk = "nltk" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"] + + +@verb(name="entity_extract") +async def entity_extract( + input: VerbInput, + cache: PipelineCache, + callbacks: VerbCallbacks, + column: str, + id_column: str, + to: str, + strategy: dict[str, Any] | None, + graph_to: str | None = None, + async_mode: AsyncType = AsyncType.AsyncIO, + entity_types=DEFAULT_ENTITY_TYPES, + **kwargs, +) -> TableContainer: + """ + Extract entities from a piece of text. + + ## Usage + ### json + ```json + { + "verb": "entity_extract", + "args": { + "column": "the_document_text_column_to_extract_entities_from", /* In general this will be your document text column */ + "id_column": "the_column_with_the_unique_id_for_each_row", /* In general this will be your document id */ + "to": "the_column_to_output_the_entities_to", /* This will be a list[dict[str, Any]] a list of entities, with a name, and additional attributes */ + "graph_to": "the_column_to_output_the_graphml_to", /* Optional: This will be a graphml graph in string form which represents the entities and their relationships */ + "strategy": {...} , see strategies section below + "entity_types": ["list", "of", "entity", "types", "to", "extract"] /* Optional: This will limit the entity types extracted, default: ["organization", "person", "geo", "event"] */ + "summarize_descriptions" : true | false /* Optional: This will summarize the descriptions of the entities and relationships, default: true */ + } + } + ``` + ### yaml + ```yaml + verb: entity_extract + args: + column: the_document_text_column_to_extract_entities_from + id_column: the_column_with_the_unique_id_for_each_row + to: the_column_to_output_the_entities_to + graph_to: the_column_to_output_the_graphml_to + strategy: , see strategies section below + summarize_descriptions: true | false /* Optional: This will summarize the descriptions of the entities and relationships, default: true */ + entity_types: + - list + - of + - entity + - types + - to + - extract + ``` + + ## Strategies + The entity extract verb uses a strategy to extract entities from a document. 
The strategy is a json object which defines the strategy to use. The following strategies are available:
+
+    ### graph_intelligence
+    This strategy uses the [graph_intelligence] library to extract entities from a document. In particular it uses an LLM to extract entities from a piece of text. The strategy config is as follows:
+
+    ```yml
+    strategy:
+        type: graph_intelligence
+        extraction_prompt: !include ./entity_extraction_prompt.txt # Optional, the prompt to use for extraction
+        completion_delimiter: "<|COMPLETE|>" # Optional, the delimiter to use for the LLM to mark completion
+        tuple_delimiter: "<|>" # Optional, the delimiter to use for the LLM to mark a tuple
+        record_delimiter: "##" # Optional, the delimiter to use for the LLM to mark a record
+
+        prechunked: true | false # Optional, whether the document is already chunked beforehand; otherwise this will chunk the document into smaller bits. default: false
+        encoding_name: cl100k_base # Optional, the encoding to use for the LLM, if not already prechunked, default: cl100k_base
+        chunk_size: 1000 # Optional, the chunk size to use for the LLM, if not already prechunked, default: 1200
+        chunk_overlap: 100 # Optional, the chunk overlap to use for the LLM, if not already prechunked, default: 100
+
+        llm: # The configuration for the LLM
+            type: openai # the type of llm to use, available options are: openai, azure, openai_chat, azure_openai_chat. The last two are chat-based LLMs.
+            api_key: !ENV ${GRAPHRAG_OPENAI_API_KEY} # The api key to use for openai
+            model: !ENV ${GRAPHRAG_OPENAI_MODEL:gpt-4-turbo-preview} # The model to use for openai
+            max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai
+            organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai
+
+            # if using azure flavor
+            api_base: !ENV ${GRAPHRAG_OPENAI_API_BASE} # The api base to use for azure
+            api_version: !ENV ${GRAPHRAG_OPENAI_API_VERSION} # The api version to use for azure
+            proxy: !ENV ${GRAPHRAG_OPENAI_PROXY} # The proxy to use for azure
+
+    ```
+
+    ### nltk
+    This strategy uses the [nltk] library to extract entities from a document. In particular it uses nltk to extract entities from a piece of text.
The strategy config is as follows: + ```yml + strategy: + type: nltk + ``` + """ + log.debug("entity_extract strategy=%s", strategy) + if entity_types is None: + entity_types = DEFAULT_ENTITY_TYPES + output = cast(pd.DataFrame, input.get_input()) + strategy = strategy or {} + strategy_exec = _load_strategy( + strategy.get("type", ExtractEntityStrategyType.graph_intelligence) + ) + strategy_config = {**strategy} + + num_started = 0 + + async def run_strategy(row): + nonlocal num_started + text = row[column] + id = row[id_column] + result = await strategy_exec( + [Document(text=text, id=id)], + entity_types, + callbacks, + cache, + strategy_config, + ) + num_started += 1 + return [result.entities, result.graphml_graph] + + results = await derive_from_rows( + output, + run_strategy, + callbacks, + scheduling_type=async_mode, + num_threads=kwargs.get("num_threads", 4), + ) + + to_result = [] + graph_to_result = [] + for result in results: + if result: + to_result.append(result[0]) + graph_to_result.append(result[1]) + else: + to_result.append(None) + graph_to_result.append(None) + + output[to] = to_result + if graph_to is not None: + output[graph_to] = graph_to_result + + return TableContainer(table=output.reset_index(drop=True)) + + +def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy: + """Load strategy method definition.""" + match strategy_type: + case ExtractEntityStrategyType.graph_intelligence: + from .strategies.graph_intelligence import run_gi + + return run_gi + + case ExtractEntityStrategyType.nltk: + bootstrap() + # dynamically import nltk strategy to avoid dependency if not used + from .strategies.nltk import run as run_nltk + + return run_nltk + case _: + msg = f"Unknown strategy: {strategy_type}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/__init__.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/__init__.py new file mode 100644 index 0000000000..f5cc17d750 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine entities extraction strategies package root.""" diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/__init__.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/__init__.py new file mode 100644 index 0000000000..083c0e4112 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph intelligence package root.""" + +from .run_graph_intelligence import run_gi + +__all__ = ["run_gi"] diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/defaults.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/defaults.py new file mode 100644 index 0000000000..237e6657c8 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/defaults.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. 
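
For quick local runs of the `entity_extract` verb above, a strategy dict with no `llm` entry makes the graph-intelligence runner fall back to its mock `StaticResponse` configuration. A sketch (key names come from the `args.get(...)` lookups in this patch; the values are assumptions):

```python
strategy = {
    "type": "graph_intelligence",
    "prechunked": True,   # rows are already chunks; skip re-chunking
    "max_gleanings": 0,   # no follow-up gleaning passes
    # "llm": {...}        # omit to use the mock DEFAULT_LLM_CONFIG
}
entity_types = ["organization", "person", "geo", "event"]  # the defaults
```
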
+# Licensed under the MIT License + +"""A file containing some default responses.""" + +from graphrag.config.enums import LLMType + +MOCK_LLM_RESPONSES = [ + """ + ("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company) + ## + ("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A) + ## + ("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A) + ## + ("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2) + ## + ("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1)) + """.strip() +] + +DEFAULT_LLM_CONFIG = { + "type": LLMType.StaticResponse, + "responses": MOCK_LLM_RESPONSES, +} diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/run_graph_intelligence.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/run_graph_intelligence.py new file mode 100644 index 0000000000..0628487983 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/run_graph_intelligence.py @@ -0,0 +1,142 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run_gi, run_extract_entities and _create_text_splitter methods to run graph intelligence.""" + +import networkx as nx +from datashaper import VerbCallbacks + +import graphrag.config.defaults as defs +from graphrag.config.enums import LLMType +from graphrag.index.cache import PipelineCache +from graphrag.index.graph.extractors.graph import GraphExtractor +from graphrag.index.llm import load_llm +from graphrag.index.text_splitting import ( + NoopTextSplitter, + TextSplitter, + TokenTextSplitter, +) +from graphrag.index.verbs.entities.extraction.strategies.typing import ( + Document, + EntityExtractionResult, + EntityTypes, + StrategyConfig, +) +from graphrag.llm import CompletionLLM + +from .defaults import DEFAULT_LLM_CONFIG + + +async def run_gi( + docs: list[Document], + entity_types: EntityTypes, + reporter: VerbCallbacks, + pipeline_cache: PipelineCache, + args: StrategyConfig, +) -> EntityExtractionResult: + """Run the graph intelligence entity extraction strategy.""" + llm_config = args.get("llm", DEFAULT_LLM_CONFIG) + llm_type = llm_config.get("type", LLMType.StaticResponse) + llm = load_llm("entity_extraction", llm_type, reporter, pipeline_cache, llm_config) + return await run_extract_entities(llm, docs, entity_types, reporter, args) + + +async def run_extract_entities( + llm: CompletionLLM, + docs: list[Document], + entity_types: EntityTypes, + reporter: VerbCallbacks | None, + args: StrategyConfig, +) -> EntityExtractionResult: + """Run the entity extraction chain.""" + encoding_name = args.get("encoding_name", "cl100k_base") + + # Chunking Arguments + prechunked = args.get("prechunked", False) + chunk_size = args.get("chunk_size", defs.CHUNK_SIZE) + chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP) + + # Extraction Arguments + tuple_delimiter = args.get("tuple_delimiter", None) + record_delimiter = args.get("record_delimiter", None) + completion_delimiter = args.get("completion_delimiter", None) + extraction_prompt = args.get("extraction_prompt", None) + encoding_model = args.get("encoding_name", None) + max_gleanings = args.get("max_gleanings", defs.ENTITY_EXTRACTION_MAX_GLEANINGS) + + # note: We're not using 
UnipartiteGraphChain.from_params + # because we want to pass "timeout" to the llm_kwargs + text_splitter = _create_text_splitter( + prechunked, chunk_size, chunk_overlap, encoding_name + ) + + extractor = GraphExtractor( + llm_invoker=llm, + prompt=extraction_prompt, + encoding_model=encoding_model, + max_gleanings=max_gleanings, + on_error=lambda e, s, d: ( + reporter.error("Entity Extraction Error", e, s, d) if reporter else None + ), + ) + text_list = [doc.text.strip() for doc in docs] + + # If it's not pre-chunked, then re-chunk the input + if not prechunked: + text_list = text_splitter.split_text("\n".join(text_list)) + + results = await extractor( + list(text_list), + { + "entity_types": entity_types, + "tuple_delimiter": tuple_delimiter, + "record_delimiter": record_delimiter, + "completion_delimiter": completion_delimiter, + }, + ) + + graph = results.output + # Map the "source_id" back to the "id" field + for _, node in graph.nodes(data=True): # type: ignore + if node is not None: + node["source_id"] = ",".join( + docs[int(id)].id for id in node["source_id"].split(",") + ) + + for _, _, edge in graph.edges(data=True): # type: ignore + if edge is not None: + edge["source_id"] = ",".join( + docs[int(id)].id for id in edge["source_id"].split(",") + ) + + entities = [ + ({"name": item[0], **(item[1] or {})}) + for item in graph.nodes(data=True) + if item is not None + ] + + graph_data = "".join(nx.generate_graphml(graph)) + return EntityExtractionResult(entities, graph_data) + + +def _create_text_splitter( + prechunked: bool, chunk_size: int, chunk_overlap: int, encoding_name: str +) -> TextSplitter: + """Create a text splitter for the extraction chain. + + Args: + - prechunked - Whether the text is already chunked + - chunk_size - The size of each chunk + - chunk_overlap - The overlap between chunks + - encoding_name - The name of the encoding to use + Returns: + - output - A text splitter + """ + if prechunked: + return NoopTextSplitter() + + return TokenTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + encoding_name=encoding_name, + ) diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/nltk.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/nltk.py new file mode 100644 index 0000000000..48d4dae4ca --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/nltk.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Microsoft Corporation. 
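
The `source_id` remapping in `run_extract_entities` above is worth a tiny illustration: extracted nodes and edges carry comma-separated *chunk indices*, which are translated back into document ids. A sketch (`Document` is the dataclass from `strategies/typing.py` in this patch):

```python
from graphrag.index.verbs.entities.extraction.strategies.typing import Document

docs = [Document(text="chunk one", id="doc-A"), Document(text="chunk two", id="doc-B")]

# A node extracted from chunks 0 and 1 carries source_id "0,1" ...
source_id = "0,1"
# ... which the remap turns into the originating document ids:
remapped = ",".join(docs[int(i)].id for i in source_id.split(","))
print(remapped)  # doc-A,doc-B
```
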
+# Licensed under the MIT License + +"""A module containing run method definition.""" + +import networkx as nx +import nltk +from datashaper import VerbCallbacks +from nltk.corpus import words + +from graphrag.index.cache import PipelineCache + +from .typing import Document, EntityExtractionResult, EntityTypes, StrategyConfig + +# Need to do this cause we're potentially multithreading, and nltk doesn't like that +words.ensure_loaded() + + +async def run( # noqa RUF029 async is required for interface + docs: list[Document], + entity_types: EntityTypes, + reporter: VerbCallbacks, # noqa ARG001 + pipeline_cache: PipelineCache, # noqa ARG001 + args: StrategyConfig, # noqa ARG001 +) -> EntityExtractionResult: + """Run method definition.""" + entity_map = {} + graph = nx.Graph() + for doc in docs: + connected_entities = [] + for chunk in nltk.ne_chunk(nltk.pos_tag(nltk.word_tokenize(doc.text))): + if hasattr(chunk, "label"): + entity_type = chunk.label().lower() + if entity_type in entity_types: + name = (" ".join(c[0] for c in chunk)).upper() + connected_entities.append(name) + if name not in entity_map: + entity_map[name] = entity_type + graph.add_node( + name, type=entity_type, description=name, source_id=doc.id + ) + + # connect the entities if they appear in the same document + if len(connected_entities) > 1: + for i in range(len(connected_entities)): + for j in range(i + 1, len(connected_entities)): + description = f"{connected_entities[i]} -> {connected_entities[j]}" + graph.add_edge( + connected_entities[i], + connected_entities[j], + description=description, + source_id=doc.id, + ) + + return EntityExtractionResult( + entities=[ + {"type": entity_type, "name": name} + for name, entity_type in entity_map.items() + ], + graphml_graph="".join(nx.generate_graphml(graph)), + ) diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/typing.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/typing.py new file mode 100644 index 0000000000..45d3f1b80e --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/typing.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'Document' and 'EntityExtractionResult' models.""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any + +from datashaper import VerbCallbacks + +from graphrag.index.cache import PipelineCache + +ExtractedEntity = dict[str, Any] +StrategyConfig = dict[str, Any] +EntityTypes = list[str] + + +@dataclass +class Document: + """Document class definition.""" + + text: str + id: str + + +@dataclass +class EntityExtractionResult: + """Entity extraction result class definition.""" + + entities: list[ExtractedEntity] + graphml_graph: str | None + + +EntityExtractStrategy = Callable[ + [ + list[Document], + EntityTypes, + VerbCallbacks, + PipelineCache, + StrategyConfig, + ], + Awaitable[EntityExtractionResult], +] diff --git a/func-app/graphrag/index/verbs/entities/summarize/__init__.py b/func-app/graphrag/index/verbs/entities/summarize/__init__.py new file mode 100644 index 0000000000..d7e9a5d93a --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/summarize/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. 
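
The nltk strategy above leans entirely on nltk's classic tokenize/POS-tag/NE-chunk pipeline. A hedged standalone sketch (the corpus package names below match nltk 3.8; newer releases may also require the `*_tab` variants, and the emitted labels depend on the bundled model):

```python
import nltk

# One-time downloads the strategy relies on (names are an assumption).
for pkg in ("punkt", "averaged_perceptron_tagger", "maxent_ne_chunker", "words"):
    nltk.download(pkg, quiet=True)

sent = "Satya Nadella leads Microsoft in Redmond."
for chunk in nltk.ne_chunk(nltk.pos_tag(nltk.word_tokenize(sent))):
    if hasattr(chunk, "label"):  # same test as the strategy above
        print(chunk.label().lower(), " ".join(c[0] for c in chunk).upper())
# e.g. person SATYA NADELLA / organization MICROSOFT / gpe REDMOND
```
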
+# Licensed under the MIT License
+
+"""Root package for entity summarization."""
+
+from .description_summarize import SummarizeStrategyType, summarize_descriptions
+
+__all__ = ["SummarizeStrategyType", "summarize_descriptions"]
diff --git a/func-app/graphrag/index/verbs/entities/summarize/description_summarize.py b/func-app/graphrag/index/verbs/entities/summarize/description_summarize.py
new file mode 100644
index 0000000000..5b7feb4184
--- /dev/null
+++ b/func-app/graphrag/index/verbs/entities/summarize/description_summarize.py
@@ -0,0 +1,207 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the summarize_descriptions verb."""
+
+import asyncio
+import logging
+from enum import Enum
+from typing import Any, NamedTuple, cast
+
+import networkx as nx
+import pandas as pd
+from datashaper import (
+    ProgressTicker,
+    TableContainer,
+    VerbCallbacks,
+    VerbInput,
+    progress_ticker,
+    verb,
+)
+
+from graphrag.index.cache import PipelineCache
+from graphrag.index.utils import load_graph
+
+from .strategies.typing import SummarizationStrategy
+
+log = logging.getLogger(__name__)
+
+
+class DescriptionSummarizeRow(NamedTuple):
+    """DescriptionSummarizeRow class definition."""
+
+    graph: Any
+
+
+class SummarizeStrategyType(str, Enum):
+    """SummarizeStrategyType class definition."""
+
+    graph_intelligence = "graph_intelligence"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+@verb(name="summarize_descriptions")
+async def summarize_descriptions(
+    input: VerbInput,
+    cache: PipelineCache,
+    callbacks: VerbCallbacks,
+    column: str,
+    to: str,
+    strategy: dict[str, Any] | None = None,
+    **kwargs,
+) -> TableContainer:
+    """
+    Summarize entity and relationship descriptions from an entity graph.
+
+    ## Usage
+
+    To turn this feature ON, please set the environment variable `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_ENABLED=True`.
+
+    ### json
+
+    ```json
+    {
+        "verb": "summarize_descriptions",
+        "args": {
+            "column": "the_document_text_column_to_extract_descriptions_from", /* Required: This will be a graphml graph in string form which represents the entities and their relationships */
+            "to": "the_column_to_output_the_summarized_descriptions_to", /* Required: This will be a graphml graph in string form which represents the entities and their relationships after being summarized */
+            "strategy": {...}, see strategies section below
+        }
+    }
+    ```
+
+    ### yaml
+
+    ```yaml
+    verb: summarize_descriptions
+    args:
+        column: the_document_text_column_to_extract_descriptions_from
+        to: the_column_to_output_the_summarized_descriptions_to
+        strategy: <strategy config>, see strategies section below
+    ```
+
+    ## Strategies
+
+    The summarize descriptions verb uses a strategy to summarize descriptions for entities. The strategy is a json object which defines the strategy to use. The following strategies are available:
+
+    ### graph_intelligence
+
+    This strategy uses the [graph_intelligence] library to summarize descriptions for entities. The strategy config is as follows:
+
+    ```yml
+    strategy:
+        type: graph_intelligence
+        summarize_prompt: # Optional, the prompt to use for extraction
+
+
+        llm: # The configuration for the LLM
+            type: openai # the type of llm to use, available options are: openai, azure, openai_chat, azure_openai_chat. The last two are chat-based LLMs.
+ api_key: !ENV ${GRAPHRAG_OPENAI_API_KEY} # The api key to use for openai + model: !ENV ${GRAPHRAG_OPENAI_MODEL:gpt-4-turbo-preview} # The model to use for openai + max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai + organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai + + # if using azure flavor + api_base: !ENV ${GRAPHRAG_OPENAI_API_BASE} # The api base to use for azure + api_version: !ENV ${GRAPHRAG_OPENAI_API_VERSION} # The api version to use for azure + proxy: !ENV ${GRAPHRAG_OPENAI_PROXY} # The proxy to use for azure + ``` + """ + log.debug("summarize_descriptions strategy=%s", strategy) + output = cast(pd.DataFrame, input.get_input()) + strategy = strategy or {} + strategy_exec = load_strategy( + strategy.get("type", SummarizeStrategyType.graph_intelligence) + ) + strategy_config = {**strategy} + + async def get_resolved_entities(row, semaphore: asyncio.Semaphore): + graph: nx.Graph = load_graph(cast(str | nx.Graph, getattr(row, column))) + + ticker_length = len(graph.nodes) + len(graph.edges) + + ticker = progress_ticker(callbacks.progress, ticker_length) + + futures = [ + do_summarize_descriptions( + node, + sorted(set(graph.nodes[node].get("description", "").split("\n"))), + ticker, + semaphore, + ) + for node in graph.nodes() + ] + futures += [ + do_summarize_descriptions( + edge, + sorted(set(graph.edges[edge].get("description", "").split("\n"))), + ticker, + semaphore, + ) + for edge in graph.edges() + ] + + results = await asyncio.gather(*futures) + + for result in results: + graph_item = result.items + if isinstance(graph_item, str) and graph_item in graph.nodes(): + graph.nodes[graph_item]["description"] = result.description + elif isinstance(graph_item, tuple) and graph_item in graph.edges(): + graph.edges[graph_item]["description"] = result.description + + return DescriptionSummarizeRow( + graph="\n".join(nx.generate_graphml(graph)), + ) + + async def do_summarize_descriptions( + graph_item: str | tuple[str, str], + descriptions: list[str], + ticker: ProgressTicker, + semaphore: asyncio.Semaphore, + ): + async with semaphore: + results = await strategy_exec( + graph_item, + descriptions, + callbacks, + cache, + strategy_config, + ) + ticker(1) + return results + + # Graph is always on row 0, so here a derive from rows does not work + # This iteration will only happen once, but avoids hardcoding a iloc[0] + # Since parallelization is at graph level (nodes and edges), we can't use + # the parallelization of the derive_from_rows + semaphore = asyncio.Semaphore(kwargs.get("num_threads", 4)) + + results = [ + await get_resolved_entities(row, semaphore) for row in output.itertuples() + ] + + to_result = [] + + for result in results: + if result: + to_result.append(result.graph) + else: + to_result.append(None) + output[to] = to_result + return TableContainer(table=output) + + +def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy: + """Load strategy method definition.""" + match strategy_type: + case SummarizeStrategyType.graph_intelligence: + from .strategies.graph_intelligence import run as run_gi + + return run_gi + case _: + msg = f"Unknown strategy: {strategy_type}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/entities/summarize/strategies/__init__.py b/func-app/graphrag/index/verbs/entities/summarize/strategies/__init__.py new file mode 100644 index 0000000000..28c398e6ac --- /dev/null +++ 
b/func-app/graphrag/index/verbs/entities/summarize/strategies/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Indexing Engine - Summarization Strategies Package.""" + +from .typing import SummarizationStrategy + +__all__ = ["SummarizationStrategy"] diff --git a/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/__init__.py b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/__init__.py new file mode 100644 index 0000000000..a98d9406cb --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Entity summarization graph intelligence package root.""" + +from .run_graph_intelligence import run + +__all__ = ["run"] diff --git a/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/defaults.py b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/defaults.py new file mode 100644 index 0000000000..8ac42aa13d --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/defaults.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A file containing some default responses.""" + +from graphrag.config.enums import LLMType + +MOCK_LLM_RESPONSES = [ + """ + This is a MOCK response for the LLM. It is summarized! + """.strip() +] + +DEFAULT_LLM_CONFIG = { + "type": LLMType.StaticResponse, + "responses": MOCK_LLM_RESPONSES, +} diff --git a/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/run_graph_intelligence.py b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/run_graph_intelligence.py new file mode 100644 index 0000000000..57a1ecd218 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/run_graph_intelligence.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing run_gi, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence.""" + +from datashaper import VerbCallbacks + +from graphrag.config.enums import LLMType +from graphrag.index.cache import PipelineCache +from graphrag.index.graph.extractors.summarize import SummarizeExtractor +from graphrag.index.llm import load_llm +from graphrag.index.verbs.entities.summarize.strategies.typing import ( + StrategyConfig, + SummarizedDescriptionResult, +) +from graphrag.llm import CompletionLLM + +from .defaults import DEFAULT_LLM_CONFIG + + +async def run( + described_items: str | tuple[str, str], + descriptions: list[str], + reporter: VerbCallbacks, + pipeline_cache: PipelineCache, + args: StrategyConfig, +) -> SummarizedDescriptionResult: + """Run the graph intelligence entity extraction strategy.""" + llm_config = args.get("llm", DEFAULT_LLM_CONFIG) + llm_type = llm_config.get("type", LLMType.StaticResponse) + llm = load_llm( + "summarize_descriptions", llm_type, reporter, pipeline_cache, llm_config + ) + return await run_summarize_descriptions( + llm, described_items, descriptions, reporter, args + ) + + +async def run_summarize_descriptions( + llm: CompletionLLM, + items: str | tuple[str, str], + descriptions: list[str], + reporter: VerbCallbacks, + args: StrategyConfig, +) -> SummarizedDescriptionResult: + """Run the entity extraction chain.""" + # Extraction Arguments + summarize_prompt = args.get("summarize_prompt", None) + entity_name_key = args.get("entity_name_key", "entity_name") + input_descriptions_key = args.get("input_descriptions_key", "description_list") + max_tokens = args.get("max_tokens", None) + + extractor = SummarizeExtractor( + llm_invoker=llm, + summarization_prompt=summarize_prompt, + entity_name_key=entity_name_key, + input_descriptions_key=input_descriptions_key, + on_error=lambda e, stack, details: ( + reporter.error("Entity Extraction Error", e, stack, details) + if reporter + else None + ), + max_summary_length=args.get("max_summary_length", None), + max_input_tokens=max_tokens, + ) + + result = await extractor(items=items, descriptions=descriptions) + return SummarizedDescriptionResult( + items=result.items, description=result.description + ) diff --git a/func-app/graphrag/index/verbs/entities/summarize/strategies/typing.py b/func-app/graphrag/index/verbs/entities/summarize/strategies/typing.py new file mode 100644 index 0000000000..398295031b --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/summarize/strategies/typing.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Microsoft Corporation. 
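
Since `run_summarize_descriptions` above pulls every setting from `args.get(...)`, the full surface of a strategy config is small. A sketch, with every value an illustrative assumption rather than a shipped default:

```python
# Keys mirror the args.get(...) lookups in run_summarize_descriptions.
args = {
    "summarize_prompt": None,                      # use the extractor's default
    "entity_name_key": "entity_name",              # the documented default
    "input_descriptions_key": "description_list",  # the documented default
    "max_tokens": 4000,                            # cap on summarizer input
    "max_summary_length": 500,
}
```
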
+# Licensed under the MIT License + +"""A module containing 'SummarizedDescriptionResult' model.""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any + +from datashaper import VerbCallbacks + +from graphrag.index.cache import PipelineCache + +StrategyConfig = dict[str, Any] + + +@dataclass +class SummarizedDescriptionResult: + """Entity summarization result class definition.""" + + items: str | tuple[str, str] + description: str + + +SummarizationStrategy = Callable[ + [ + str | tuple[str, str], + list[str], + VerbCallbacks, + PipelineCache, + StrategyConfig, + ], + Awaitable[SummarizedDescriptionResult], +] diff --git a/func-app/graphrag/index/verbs/genid.py b/func-app/graphrag/index/verbs/genid.py new file mode 100644 index 0000000000..019ffc2da0 --- /dev/null +++ b/func-app/graphrag/index/verbs/genid.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing genid method definition.""" + +from typing import cast + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + +from graphrag.index.utils import gen_md5_hash + + +@verb(name="genid") +def genid( + input: VerbInput, + to: str, + method: str = "md5_hash", + hash: list[str] = [], # noqa A002 + **_kwargs: dict, +) -> TableContainer: + """ + Generate a unique id for each row in the tabular data. + + ## Usage + ### json + ```json + { + "verb": "genid", + "args": { + "to": "id_output_column_name", /* The name of the column to output the id to */ + "method": "md5_hash", /* The method to use to generate the id */ + "hash": ["list", "of", "column", "names"] /* only if using md5_hash */, + "seed": 034324 /* The random seed to use with UUID */ + } + } + ``` + + ### yaml + ```yaml + verb: genid + args: + to: id_output_column_name + method: md5_hash + hash: + - list + - of + - column + - names + seed: 034324 + ``` + """ + data = cast(pd.DataFrame, input.source.table) + + if method == "md5_hash": + if len(hash) == 0: + msg = 'Must specify the "hash" columns to use md5_hash method' + raise ValueError(msg) + + data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1) + elif method == "increment": + data[to] = data.index + 1 + else: + msg = f"Unknown method {method}" + raise ValueError(msg) + return TableContainer(table=data) diff --git a/func-app/graphrag/index/verbs/graph/__init__.py b/func-app/graphrag/index/verbs/graph/__init__.py new file mode 100644 index 0000000000..5edbdbe530 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) 2024 Microsoft Corporation. 
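
The `genid` verb above has exactly two paths, both of which are one-liners over a frame. A miniature sketch using the exported `gen_md5_hash` (illustrative frame; `md5_hash` mirrors the verb's own `apply` call):

```python
import pandas as pd

from graphrag.index.utils import gen_md5_hash  # exported in utils/__init__ above

frame = pd.DataFrame({"title": ["a", "b"], "text": ["x", "y"]})

# The md5_hash path: hash the named columns per row into a stable id.
frame["id"] = frame.apply(lambda row: gen_md5_hash(row, ["title", "text"]), axis=1)

# The increment path is just a 1-based index.
frame["row_num"] = frame.index + 1
```
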
+# Licensed under the MIT License + +"""The Indexing Engine graph package root.""" + +from .clustering import cluster_graph +from .compute_edge_combined_degree import compute_edge_combined_degree +from .create import DEFAULT_EDGE_ATTRIBUTES, DEFAULT_NODE_ATTRIBUTES, create_graph +from .embed import embed_graph +from .layout import layout_graph +from .merge import merge_graphs +from .report import ( + create_community_reports, + prepare_community_reports, + prepare_community_reports_claims, + prepare_community_reports_edges, + restore_community_hierarchy, +) +from .unpack import unpack_graph + +__all__ = [ + "DEFAULT_EDGE_ATTRIBUTES", + "DEFAULT_NODE_ATTRIBUTES", + "cluster_graph", + "compute_edge_combined_degree", + "create_community_reports", + "create_graph", + "embed_graph", + "layout_graph", + "merge_graphs", + "prepare_community_reports", + "prepare_community_reports_claims", + "prepare_community_reports_edges", + "restore_community_hierarchy", + "unpack_graph", +] diff --git a/func-app/graphrag/index/verbs/graph/clustering/__init__.py b/func-app/graphrag/index/verbs/graph/clustering/__init__.py new file mode 100644 index 0000000000..a5db89bb7f --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/clustering/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph clustering package root.""" + +from .cluster_graph import GraphCommunityStrategyType, cluster_graph + +__all__ = ["GraphCommunityStrategyType", "cluster_graph"] diff --git a/func-app/graphrag/index/verbs/graph/clustering/cluster_graph.py b/func-app/graphrag/index/verbs/graph/clustering/cluster_graph.py new file mode 100644 index 0000000000..0cfb929c63 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/clustering/cluster_graph.py @@ -0,0 +1,182 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing cluster_graph, apply_clustering and run_layout methods definition.""" + +import logging +from enum import Enum +from random import Random +from typing import Any, cast + +import networkx as nx +import pandas as pd +from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb + +from graphrag.index.utils import gen_uuid, load_graph + +from .typing import Communities +from hashlib import sha256 + +log = logging.getLogger(__name__) + + +@verb(name="cluster_graph") +def cluster_graph( + input: VerbInput, + callbacks: VerbCallbacks, + strategy: dict[str, Any], + column: str, + to: str, + level_to: str | None = None, + **_kwargs, +) -> TableContainer: + """ + Apply a hierarchical clustering algorithm to a graph. The graph is expected to be in graphml format. The verb outputs a new column containing the clustered graph, and a new column containing the level of the graph. + + ## Usage + ```yaml + verb: cluster_graph + args: + column: entity_graph # The name of the column containing the graph, should be a graphml graph + to: clustered_graph # The name of the column to output the clustered graph to + level_to: level # The name of the column to output the level to + strategy: # See strategies section below + ``` + + ## Strategies + The cluster graph verb uses a strategy to cluster the graph. The strategy is a json object which defines the strategy to use. The following strategies are available: + + ### leiden + This strategy uses the leiden algorithm to cluster a graph. 
The strategy config is as follows:
+    ```yaml
+    strategy:
+        type: leiden
+        max_cluster_size: 10 # Optional, the max cluster size to use, default: 10
+        use_lcc: true # Optional, if the largest connected component should be used with the leiden algorithm, default: true
+        seed: 0xDEADBEEF # Optional, the seed to use for the leiden algorithm, default: 0xDEADBEEF
+        levels: [0, 1] # Optional, the levels to output, default: all the levels detected
+
+    ```
+    """
+    output_df = cast(pd.DataFrame, input.get_input())
+    results = output_df[column].apply(lambda graph: run_layout(strategy, graph))
+
+    community_map_to = "communities"
+    output_df[community_map_to] = results
+
+    level_to = level_to or f"{to}_level"
+    output_df[level_to] = output_df.apply(
+        lambda x: list({level for level, _, _ in x[community_map_to]}), axis=1
+    )
+    output_df[to] = [None] * len(output_df)
+
+    num_total = len(output_df)
+
+    # Go through each of the rows
+    graph_level_pairs_column: list[list[tuple[int, str]]] = []
+    for _, row in progress_iterable(
+        output_df.iterrows(), callbacks.progress, num_total
+    ):
+        levels = row[level_to]
+        graph_level_pairs: list[tuple[int, str]] = []
+
+        # For each of the levels, get the graph and add it to the list
+        for level in levels:
+            graph = "\n".join(
+                nx.generate_graphml(
+                    apply_clustering(
+                        cast(str, row[column]),
+                        cast(Communities, row[community_map_to]),
+                        level,
+                    )
+                )
+            )
+            graph_level_pairs.append((level, graph))
+        graph_level_pairs_column.append(graph_level_pairs)
+    output_df[to] = graph_level_pairs_column
+
+    # explode the list of (level, graph) pairs into separate rows
+    output_df = output_df.explode(to, ignore_index=True)
+
+    # split the (level, graph) pairs into separate columns
+    # TODO: There is probably a better way to do this
+    output_df[[level_to, to]] = pd.DataFrame(
+        output_df[to].tolist(), index=output_df.index
+    )
+
+    # clean up the community map
+    output_df.drop(columns=[community_map_to], inplace=True)
+
+    return TableContainer(table=output_df)
+
+def generate_entity_id(candidate: str) -> str:
+    h = sha256()
+    h.update(candidate.encode())
+    return h.hexdigest()
+
+# TODO: This should support str | nx.Graph as a graphml param
+def apply_clustering(
+    graphml: str, communities: Communities, level=0
+) -> nx.Graph:
+    """Apply clustering to a graphml string."""
+
+    graph = nx.parse_graphml(graphml)
+    for community_level, community_id, nodes in communities:
+        if level == community_level:
+            for node in nodes:
+                graph.nodes[node]["cluster"] = community_id
+                graph.nodes[node]["level"] = level
+
+    # add node degree
+    for node_degree in graph.degree:
+        graph.nodes[str(node_degree[0])]["degree"] = int(node_degree[1])
+
+    # Generate a unique ID for each entity and an incremental record id (a human readable id used as reference in the final report)
+    for index, node in enumerate(graph.nodes()):
+        graph.nodes[node]["human_readable_id"] = index
+        graph.nodes[node]["id"] = generate_entity_id(node)
+
+    # add ids to edges
+    for index, edge in enumerate(graph.edges()):
+        graph.edges[edge]["human_readable_id"] = index
+        graph.edges[edge]["level"] = level
+        graph.edges[edge]["id"] = generate_entity_id(f"{edge[0]}:{edge[1]}")
+
+    return graph
+
+
+class GraphCommunityStrategyType(str, Enum):
+    """GraphCommunityStrategyType class definition."""
+
+    leiden = "leiden"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+def run_layout(
+    strategy: dict[str, Any], graphml_or_graph: str | nx.Graph
+) -> Communities:
+    """Run layout method definition."""
+ graph = load_graph(graphml_or_graph) + if len(graph.nodes) == 0: + log.warning("Graph has no nodes") + return [] + + clusters: dict[int, dict[str, list[str]]] = {} + strategy_type = strategy.get("type", GraphCommunityStrategyType.leiden) + match strategy_type: + case GraphCommunityStrategyType.leiden: + from .strategies.leiden import run as run_leiden + + clusters = run_leiden(graph, strategy) + case _: + msg = f"Unknown clustering strategy {strategy_type}" + raise ValueError(msg) + + results: Communities = [] + for level in clusters: + for cluster_id, nodes in clusters[level].items(): + results.append((level, cluster_id, nodes)) + return results diff --git a/func-app/graphrag/index/verbs/graph/clustering/strategies/__init__.py b/func-app/graphrag/index/verbs/graph/clustering/strategies/__init__.py new file mode 100644 index 0000000000..16a03f12d6 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/clustering/strategies/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Graph Clustering Strategies.""" diff --git a/func-app/graphrag/index/verbs/graph/clustering/strategies/leiden.py b/func-app/graphrag/index/verbs/graph/clustering/strategies/leiden.py new file mode 100644 index 0000000000..ffc3688041 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/clustering/strategies/leiden.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run and _compute_leiden_communities methods definitions.""" + +import logging +from typing import Any + +import networkx as nx +from graspologic.partition import hierarchical_leiden + +from graphrag.index.graph.utils import stable_largest_connected_component + +log = logging.getLogger(__name__) + + +def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, list[str]]]: + """Run method definition.""" + max_cluster_size = args.get("max_cluster_size", 10) + use_lcc = args.get("use_lcc", True) + if args.get("verbose", False): + log.info( + "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc + ) + + node_id_to_community_map = _compute_leiden_communities( + graph=graph, + max_cluster_size=max_cluster_size, + use_lcc=use_lcc, + seed=args.get("seed", 0xDEADBEEF), + ) + levels = args.get("levels") + + # If they don't pass in levels, use them all + if levels is None: + levels = sorted(node_id_to_community_map.keys()) + + results_by_level: dict[int, dict[str, list[str]]] = {} + for level in levels: + result = {} + results_by_level[level] = result + for node_id, raw_community_id in node_id_to_community_map[level].items(): + community_id = str(raw_community_id) + if community_id not in result: + result[community_id] = [] + result[community_id].append(node_id) + return results_by_level + + +# Taken from graph_intelligence & adapted +def _compute_leiden_communities( + graph: nx.Graph | nx.DiGraph, + max_cluster_size: int, + use_lcc: bool, + seed=0xDEADBEEF, +) -> dict[int, dict[str, int]]: + """Return Leiden root communities.""" + if use_lcc: + graph = stable_largest_connected_component(graph) + + community_mapping = hierarchical_leiden( + graph, max_cluster_size=max_cluster_size, random_seed=seed + ) + results: dict[int, dict[str, int]] = {} + for partition in community_mapping: + results[partition.level] = results.get(partition.level, {}) + results[partition.level][partition.node] = partition.cluster + + return results diff --git a/func-app/graphrag/index/verbs/graph/clustering/typing.py 
b/func-app/graphrag/index/verbs/graph/clustering/typing.py new file mode 100644 index 0000000000..4d6fc7e601 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/clustering/typing.py @@ -0,0 +1,6 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing Communities list definition.""" + +Communities = list[tuple[int, str, list[str]]] diff --git a/func-app/graphrag/index/verbs/graph/compute_edge_combined_degree.py b/func-app/graphrag/index/verbs/graph/compute_edge_combined_degree.py new file mode 100644 index 0000000000..1f2dd71972 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/compute_edge_combined_degree.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition.""" + +from typing import cast + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + +from graphrag.index.utils.ds_util import get_required_input_table + + +@verb(name="compute_edge_combined_degree") +def compute_edge_combined_degree( + input: VerbInput, + to: str = "rank", + node_name_column: str = "title", + node_degree_column: str = "degree", + edge_source_column: str = "source", + edge_target_column: str = "target", + **_kwargs, +) -> TableContainer: + """ + Compute the combined degree for each edge in a graph. + + Inputs Tables: + - input: The edge table + - nodes: The nodes table. + + Args: + - to: The name of the column to output the combined degree to. Default="rank" + """ + edge_df: pd.DataFrame = cast(pd.DataFrame, input.get_input()) + if to in edge_df.columns: + return TableContainer(table=edge_df) + node_degree_df = _get_node_degree_table(input, node_name_column, node_degree_column) + + def join_to_degree(df: pd.DataFrame, column: str) -> pd.DataFrame: + degree_column = _degree_colname(column) + result = df.merge( + node_degree_df.rename( + columns={node_name_column: column, node_degree_column: degree_column} + ), + on=column, + how="left", + ) + result[degree_column] = result[degree_column].fillna(0) + return result + + edge_df = join_to_degree(edge_df, edge_source_column) + edge_df = join_to_degree(edge_df, edge_target_column) + edge_df[to] = ( + edge_df[_degree_colname(edge_source_column)] + + edge_df[_degree_colname(edge_target_column)] + ) + + return TableContainer(table=edge_df) + + +def _degree_colname(column: str) -> str: + return f"{column}_degree" + + +def _get_node_degree_table( + input: VerbInput, node_name_column: str, node_degree_column: str +) -> pd.DataFrame: + nodes_container = get_required_input_table(input, "nodes") + nodes = cast(pd.DataFrame, nodes_container.table) + return cast(pd.DataFrame, nodes[[node_name_column, node_degree_column]]) diff --git a/func-app/graphrag/index/verbs/graph/create.py b/func-app/graphrag/index/verbs/graph/create.py new file mode 100644 index 0000000000..eaf06284ef --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/create.py @@ -0,0 +1,135 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition."""
+
+from typing import Any
+
+import networkx as nx
+import pandas as pd
+from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb
+
+from graphrag.index.utils import clean_str
+
+DEFAULT_NODE_ATTRIBUTES = ["label", "type", "id", "name", "description", "community"]
+DEFAULT_EDGE_ATTRIBUTES = ["label", "type", "name", "source", "target"]
+
+
+@verb(name="create_graph")
+def create_graph(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    to: str,
+    type: str,  # noqa A002
+    graph_type: str = "undirected",
+    **kwargs,
+) -> TableContainer:
+    """
+    Create a graph from a dataframe. The verb outputs a new column containing the graph.
+
+    > Note: This will roll up all rows into a single graph.
+
+    ## Usage
+    ```yaml
+    verb: create_graph
+    args:
+        type: node # The type of graph to create, one of: node, edge
+        to: <column name> # The name of the column to output the graph to, this will be a graphml graph
+        attributes: # The attributes for the nodes / edges
+            # If using the node type, the following attributes are required:
+            id: <id_column_name>
+
+            # If using the edge type, the following attributes are required:
+            source: <source_column_name>
+            target: <target_column_name>
+
+            # Other attributes can be added as follows:
+            <attribute_name>: <column_name>
+            ... for each attribute
+    ```
+    """
+    if type not in ("node", "edge"):
+        msg = f"Unknown type {type}"
+        raise ValueError(msg)
+
+    input_df = input.get_input()
+    num_total = len(input_df)
+    out_graph: nx.Graph = _create_nx_graph(graph_type)
+
+    in_attributes = (
+        _get_node_attributes(kwargs) if type == "node" else _get_edge_attributes(kwargs)
+    )
+
+    # At this point, _get_node_attributes and _get_edge_attributes have already validated
+    id_col = in_attributes.get(
+        "id", in_attributes.get("label", in_attributes.get("name", None))
+    )
+    source_col = in_attributes.get("source", None)
+    target_col = in_attributes.get("target", None)
+
+    for _, row in progress_iterable(input_df.iterrows(), callbacks.progress, num_total):
+        item_attributes = {
+            clean_str(key): _clean_value(row[value])
+            for key, value in in_attributes.items()
+            if value in row
+        }
+        if type == "node":
+            id = clean_str(row[id_col])
+            out_graph.add_node(id, **item_attributes)
+        elif type == "edge":
+            source = clean_str(row[source_col])
+            target = clean_str(row[target_col])
+            out_graph.add_edge(source, target, **item_attributes)
+
+    graphml_string = "".join(nx.generate_graphml(out_graph))
+    output_df = pd.DataFrame([{to: graphml_string}])
+    return TableContainer(table=output_df)
+
+
+def _clean_value(value: Any) -> str:
+    if value is None:
+        return ""
+    if isinstance(value, str):
+        return clean_str(value)
+
+    msg = f"Value must be a string or None, got {type(value)}"
+    raise TypeError(msg)
+
+
+def _get_node_attributes(args: dict[str, Any]) -> dict[str, Any]:
+    mapping = _get_attribute_column_mapping(
+        args.get("attributes", DEFAULT_NODE_ATTRIBUTES)
+    )
+    if "id" not in mapping and "label" not in mapping and "name" not in mapping:
+        msg = "You must specify an id, label, or name column in the node attributes"
+        raise ValueError(msg)
+    return mapping
+
+
+def _get_edge_attributes(args: dict[str, Any]) -> dict[str, Any]:
+    mapping = _get_attribute_column_mapping(
+        args.get("attributes", DEFAULT_EDGE_ATTRIBUTES)
+    )
+    if "source" not in mapping or "target" not in mapping:
+        msg = "You must specify a source and target column in the edge attributes"
+        raise ValueError(msg)
+    return mapping
+
+
+def _get_attribute_column_mapping(
+    in_attributes: dict[str, Any] | list[str],
+) -> dict[str, str]:
+    # It's already an attribute: column dict
+    if isinstance(in_attributes, dict):
+        return {
+            **in_attributes,
+        }
+
+    return {attrib: attrib for attrib in in_attributes}
+
+
+def _create_nx_graph(graph_type: str) -> nx.Graph:
+    if graph_type == "directed":
+        return nx.DiGraph()
+
+    return nx.Graph()
diff --git a/func-app/graphrag/index/verbs/graph/embed/__init__.py b/func-app/graphrag/index/verbs/graph/embed/__init__.py
new file mode 100644
index 0000000000..4ca8168c3c
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/embed/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph embed package root."""
+
+from .embed_graph import EmbedGraphStrategyType, embed_graph
+
+__all__ = ["EmbedGraphStrategyType", "embed_graph"]
diff --git a/func-app/graphrag/index/verbs/graph/embed/embed_graph.py b/func-app/graphrag/index/verbs/graph/embed/embed_graph.py
new file mode 100644
index 0000000000..8691d343f0
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/embed/embed_graph.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing embed_graph and run_embeddings methods definition."""
+
+from enum import Enum
+from typing import Any, cast
+
+import networkx as nx
+import pandas as pd
+from datashaper import TableContainer, VerbCallbacks, VerbInput, derive_from_rows, verb
+
+from graphrag.index.utils import load_graph
+
+from .typing import NodeEmbeddings
+
+
+class EmbedGraphStrategyType(str, Enum):
+    """EmbedGraphStrategyType class definition."""
+
+    node2vec = "node2vec"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+@verb(name="embed_graph")
+async def embed_graph(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any],
+    column: str,
+    to: str,
+    **kwargs,
+) -> TableContainer:
+    """
+    Embed a graph into a vector space. The graph is expected to be in graphml format. The verb outputs a new column containing a mapping between node_id and vector.
+
+    ## Usage
+    ```yaml
+    verb: embed_graph
+    args:
+        column: clustered_graph # The name of the column containing the graph, should be a graphml graph
+        to: embeddings # The name of the column to output the embeddings to
+        strategy: <strategy config> # See strategies section below
+    ```
+
+    ## Strategies
+    The embed_graph verb uses a strategy to embed the graph. The strategy is an object which defines the strategy to use. The following strategies are available:
+
+    ### node2vec
+    This strategy uses the node2vec algorithm to embed a graph.
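+    Node2vec samples random walks over the graph and trains a skip-gram model on the
+    resulting node sequences, so nodes that frequently co-occur on walks end up with
+    nearby embedding vectors.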
The strategy config is as follows: + + ```yaml + strategy: + type: node2vec + dimensions: 1536 # Optional, The number of dimensions to use for the embedding, default: 1536 + num_walks: 10 # Optional, The number of walks to use for the embedding, default: 10 + walk_length: 40 # Optional, The walk length to use for the embedding, default: 40 + window_size: 2 # Optional, The window size to use for the embedding, default: 2 + iterations: 3 # Optional, The number of iterations to use for the embedding, default: 3 + random_seed: 86 # Optional, The random seed to use for the embedding, default: 86 + ``` + """ + output_df = cast(pd.DataFrame, input.get_input()) + + strategy_type = strategy.get("type", EmbedGraphStrategyType.node2vec) + strategy_args = {**strategy} + + async def run_strategy(row): # noqa RUF029 async is required for interface + return run_embeddings(strategy_type, cast(Any, row[column]), strategy_args) + + results = await derive_from_rows( + output_df, + run_strategy, + callbacks=callbacks, + num_threads=kwargs.get("num_threads", None), + ) + output_df[to] = list(results) + return TableContainer(table=output_df) + + +def run_embeddings( + strategy: EmbedGraphStrategyType, + graphml_or_graph: str | nx.Graph, + args: dict[str, Any], +) -> NodeEmbeddings: + """Run embeddings method definition.""" + graph = load_graph(graphml_or_graph) + match strategy: + case EmbedGraphStrategyType.node2vec: + from .strategies.node_2_vec import run as run_node_2_vec + + return run_node_2_vec(graph, args) + case _: + msg = f"Unknown strategy {strategy}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/graph/embed/strategies/__init__.py b/func-app/graphrag/index/verbs/graph/embed/strategies/__init__.py new file mode 100644 index 0000000000..ef85198eb7 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/embed/strategies/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Text Embedding strategies.""" diff --git a/func-app/graphrag/index/verbs/graph/embed/strategies/node_2_vec.py b/func-app/graphrag/index/verbs/graph/embed/strategies/node_2_vec.py new file mode 100644 index 0000000000..eb329519ed --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/embed/strategies/node_2_vec.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing run method definition.""" + +from typing import Any + +import networkx as nx + +from graphrag.index.graph.embedding import embed_nod2vec +from graphrag.index.graph.utils import stable_largest_connected_component +from graphrag.index.verbs.graph.embed.typing import NodeEmbeddings + + +def run(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings: + """Run method definition.""" + if args.get("use_lcc", True): + graph = stable_largest_connected_component(graph) + + # create graph embedding using node2vec + embeddings = embed_nod2vec( + graph=graph, + dimensions=args.get("dimensions", 1536), + num_walks=args.get("num_walks", 10), + walk_length=args.get("walk_length", 40), + window_size=args.get("window_size", 2), + iterations=args.get("iterations", 3), + random_seed=args.get("random_seed", 86), + ) + + pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True) + sorted_pairs = sorted(pairs, key=lambda x: x[0]) + + return dict(sorted_pairs) diff --git a/func-app/graphrag/index/verbs/graph/embed/typing.py b/func-app/graphrag/index/verbs/graph/embed/typing.py new file mode 100644 index 0000000000..fea792c9b1 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/embed/typing.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing different lists and dictionaries.""" + +# Use this for now instead of a wrapper +from typing import Any + +NodeList = list[str] +EmbeddingList = list[Any] +NodeEmbeddings = dict[str, list[float]] +"""Label -> Embedding""" diff --git a/func-app/graphrag/index/verbs/graph/layout/__init__.py b/func-app/graphrag/index/verbs/graph/layout/__init__.py new file mode 100644 index 0000000000..74584f83ed --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/layout/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph layout package root.""" + +from .layout_graph import layout_graph + +__all__ = ["layout_graph"] diff --git a/func-app/graphrag/index/verbs/graph/layout/layout_graph.py b/func-app/graphrag/index/verbs/graph/layout/layout_graph.py new file mode 100644 index 0000000000..e1b55b1183 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/layout/layout_graph.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing layout_graph, _run_layout and _apply_layout_to_graph methods definition.""" + +from enum import Enum +from typing import Any, cast + +import networkx as nx +import pandas as pd +from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_callback, verb + +from graphrag.index.graph.visualization import GraphLayout +from graphrag.index.utils import load_graph +from graphrag.index.verbs.graph.embed.typing import NodeEmbeddings + + +class LayoutGraphStrategyType(str, Enum): + """LayoutGraphStrategyType class definition.""" + + umap = "umap" + zero = "zero" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +@verb(name="layout_graph") +def layout_graph( + input: VerbInput, + callbacks: VerbCallbacks, + strategy: dict[str, Any], + embeddings_column: str, + graph_column: str, + to: str, + graph_to: str | None = None, + **_kwargs: dict, +) -> TableContainer: + """ + Apply a layout algorithm to a graph. The graph is expected to be in graphml format. The verb outputs a new column containing the laid out graph. 
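+
+    Two strategies are available: `umap`, which projects the per-node embeddings down to
+    2D positions (and falls back to placing every node at (0, 0) if the projection fails),
+    and `zero`, which skips the projection and assigns all nodes a zero position.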
+ + ## Usage + ```yaml + verb: layout_graph + args: + graph_column: clustered_graph # The name of the column containing the graph, should be a graphml graph + embeddings_column: embeddings # The name of the column containing the embeddings + to: node_positions # The name of the column to output the node positions to + graph_to: positioned_graph # The name of the column to output the positioned graph to + strategy: # See strategies section below + ``` + + ## Strategies + The layout graph verb uses a strategy to layout the graph. The strategy is a json object which defines the strategy to use. The following strategies are available: + + ### umap + This strategy uses the umap algorithm to layout a graph. The strategy config is as follows: + ```yaml + strategy: + type: umap + n_neighbors: 5 # Optional, The number of neighbors to use for the umap algorithm, default: 5 + min_dist: 0.75 # Optional, The min distance to use for the umap algorithm, default: 0.75 + ``` + """ + output_df = cast(pd.DataFrame, input.get_input()) + + num_items = len(output_df) + strategy_type = strategy.get("type", LayoutGraphStrategyType.umap) + strategy_args = {**strategy} + + has_embeddings = embeddings_column in output_df.columns + + layouts = output_df.apply( + progress_callback( + lambda row: _run_layout( + strategy_type, + row[graph_column], + row[embeddings_column] if has_embeddings else {}, + strategy_args, + callbacks, + ), + callbacks.progress, + num_items, + ), + axis=1, + ) + output_df[to] = layouts.apply(lambda layout: [pos.to_pandas() for pos in layout]) + if graph_to is not None: + output_df[graph_to] = output_df.apply( + lambda row: _apply_layout_to_graph( + row[graph_column], cast(GraphLayout, layouts[row.name]) + ), + axis=1, + ) + return TableContainer(table=output_df) + + +def _run_layout( + strategy: LayoutGraphStrategyType, + graphml_or_graph: str | nx.Graph, + embeddings: NodeEmbeddings, + args: dict[str, Any], + reporter: VerbCallbacks, +) -> GraphLayout: + graph = load_graph(graphml_or_graph) + match strategy: + case LayoutGraphStrategyType.umap: + from .methods.umap import run as run_umap + + return run_umap( + graph, + embeddings, + args, + lambda e, stack, d: reporter.error("Error in Umap", e, stack, d), + ) + case LayoutGraphStrategyType.zero: + from .methods.zero import run as run_zero + + return run_zero( + graph, + args, + lambda e, stack, d: reporter.error("Error in Zero", e, stack, d), + ) + case _: + msg = f"Unknown strategy {strategy}" + raise ValueError(msg) + + +def _apply_layout_to_graph( + graphml_or_graph: str | nx.Graph, layout: GraphLayout +) -> str: + graph = load_graph(graphml_or_graph) + for node_position in layout: + if node_position.label in graph.nodes: + graph.nodes[node_position.label]["x"] = node_position.x + graph.nodes[node_position.label]["y"] = node_position.y + graph.nodes[node_position.label]["size"] = node_position.size + return "\n".join(nx.generate_graphml(graph)) diff --git a/func-app/graphrag/index/verbs/graph/layout/methods/__init__.py b/func-app/graphrag/index/verbs/graph/layout/methods/__init__.py new file mode 100644 index 0000000000..5d5054122b --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/layout/methods/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Graph Layout Methods.""" diff --git a/func-app/graphrag/index/verbs/graph/layout/methods/umap.py b/func-app/graphrag/index/verbs/graph/layout/methods/umap.py new file mode 100644 index 0000000000..a4bc7c2818 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/layout/methods/umap.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run and _create_node_position methods definitions.""" + +import logging +import traceback +from typing import Any + +import networkx as nx +import numpy as np + +from graphrag.index.graph.visualization import ( + GraphLayout, + NodePosition, + compute_umap_positions, +) +from graphrag.index.typing import ErrorHandlerFn +from graphrag.index.verbs.graph.embed.typing import NodeEmbeddings + +# TODO: This could be handled more elegantly, like what columns to use +# for "size" or "cluster" +# We could also have a boolean to indicate to use node sizes or clusters + +log = logging.getLogger(__name__) + + +def run( + graph: nx.Graph, + embeddings: NodeEmbeddings, + args: dict[str, Any], + on_error: ErrorHandlerFn, +) -> GraphLayout: + """Run method definition.""" + node_clusters = [] + node_sizes = [] + + embeddings = _filter_raw_embeddings(embeddings) + nodes = list(embeddings.keys()) + embedding_vectors = [embeddings[node_id] for node_id in nodes] + + for node_id in nodes: + node = graph.nodes[node_id] + cluster = node.get("cluster", node.get("community", -1)) + node_clusters.append(cluster) + size = node.get("degree", node.get("size", 0)) + node_sizes.append(size) + + additional_args = {} + if len(node_clusters) > 0: + additional_args["node_categories"] = node_clusters + if len(node_sizes) > 0: + additional_args["node_sizes"] = node_sizes + + try: + return compute_umap_positions( + embedding_vectors=np.array(embedding_vectors), + node_labels=nodes, + **additional_args, + min_dist=args.get("min_dist", 0.75), + n_neighbors=args.get("n_neighbors", 5), + ) + except Exception as e: + log.exception("Error running UMAP") + on_error(e, traceback.format_exc(), None) + # Umap may fail due to input sparseness or memory pressure. + # For now, in these cases, we'll just return a layout with all nodes at (0, 0) + result = [] + for i in range(len(nodes)): + cluster = node_clusters[i] if len(node_clusters) > 0 else 1 + result.append( + NodePosition(x=0, y=0, label=nodes[i], size=0, cluster=str(cluster)) + ) + return result + + +def _filter_raw_embeddings(embeddings: NodeEmbeddings) -> NodeEmbeddings: + return { + node_id: embedding + for node_id, embedding in embeddings.items() + if embedding is not None + } diff --git a/func-app/graphrag/index/verbs/graph/layout/methods/zero.py b/func-app/graphrag/index/verbs/graph/layout/methods/zero.py new file mode 100644 index 0000000000..f41d2d4ca4 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/layout/methods/zero.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing run and _create_node_position methods definitions.""" + +import logging +import traceback +from typing import Any + +import networkx as nx + +from graphrag.index.graph.visualization import ( + GraphLayout, + NodePosition, + get_zero_positions, +) +from graphrag.index.typing import ErrorHandlerFn + +# TODO: This could be handled more elegantly, like what columns to use +# for "size" or "cluster" +# We could also have a boolean to indicate to use node sizes or clusters + +log = logging.getLogger(__name__) + + +def run( + graph: nx.Graph, + _args: dict[str, Any], + on_error: ErrorHandlerFn, +) -> GraphLayout: + """Run method definition.""" + node_clusters = [] + node_sizes = [] + + nodes = list(graph.nodes) + + for node_id in nodes: + node = graph.nodes[node_id] + cluster = node.get("cluster", node.get("community", -1)) + node_clusters.append(cluster) + size = node.get("degree", node.get("size", 0)) + node_sizes.append(size) + + additional_args = {} + if len(node_clusters) > 0: + additional_args["node_categories"] = node_clusters + if len(node_sizes) > 0: + additional_args["node_sizes"] = node_sizes + + try: + return get_zero_positions(node_labels=nodes, **additional_args) + except Exception as e: + log.exception("Error running zero-position") + on_error(e, traceback.format_exc(), None) + # Umap may fail due to input sparseness or memory pressure. + # For now, in these cases, we'll just return a layout with all nodes at (0, 0) + result = [] + for i in range(len(nodes)): + cluster = node_clusters[i] if len(node_clusters) > 0 else 1 + result.append( + NodePosition(x=0, y=0, label=nodes[i], size=0, cluster=str(cluster)) + ) + return result diff --git a/func-app/graphrag/index/verbs/graph/merge/__init__.py b/func-app/graphrag/index/verbs/graph/merge/__init__.py new file mode 100644 index 0000000000..f718827942 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/merge/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph merge package root.""" + +from .merge_graphs import merge_graphs + +__all__ = ["merge_graphs"] diff --git a/func-app/graphrag/index/verbs/graph/merge/defaults.py b/func-app/graphrag/index/verbs/graph/merge/defaults.py new file mode 100644 index 0000000000..80c60331c6 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/merge/defaults.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A file containing DEFAULT_NODE_OPERATIONS, DEFAULT_EDGE_OPERATIONS and DEFAULT_CONCAT_SEPARATOR values definition.""" + +from .typing import BasicMergeOperation + +DEFAULT_NODE_OPERATIONS = { + "*": { + "operation": BasicMergeOperation.Replace, + } +} + +DEFAULT_EDGE_OPERATIONS = { + "*": { + "operation": BasicMergeOperation.Replace, + }, + "weight": "sum", +} + +DEFAULT_CONCAT_SEPARATOR = "," diff --git a/func-app/graphrag/index/verbs/graph/merge/merge_graphs.py b/func-app/graphrag/index/verbs/graph/merge/merge_graphs.py new file mode 100644 index 0000000000..8ab3fa47f7 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/merge/merge_graphs.py @@ -0,0 +1,217 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A module containing merge_graphs, merge_nodes, merge_edges, merge_attributes, apply_merge_operation and _get_detailed_attribute_merge_operation methods definitions."""
+
+from typing import Any, cast
+
+import networkx as nx
+import pandas as pd
+from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb
+
+from graphrag.index.utils import load_graph
+
+from .defaults import (
+    DEFAULT_CONCAT_SEPARATOR,
+    DEFAULT_EDGE_OPERATIONS,
+    DEFAULT_NODE_OPERATIONS,
+)
+from .typing import (
+    BasicMergeOperation,
+    DetailedAttributeMergeOperation,
+    NumericOperation,
+    StringOperation,
+)
+
+
+@verb(name="merge_graphs")
+def merge_graphs(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    column: str,
+    to: str,
+    nodes: dict[str, Any] = DEFAULT_NODE_OPERATIONS,
+    edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
+    **_kwargs,
+) -> TableContainer:
+    """
+    Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph.
+
+    > Note: This will merge all rows into a single graph.
+
+    ## Usage
+    ```yaml
+    verb: merge_graph
+    args:
+        column: clustered_graph # The name of the column containing the graph, should be a graphml graph
+        to: merged_graph # The name of the column to output the merged graph to
+        nodes: <node operations> # See node operations section below
+        edges: <edge operations> # See edge operations section below
+    ```
+
+    ## Node Operations
+    The merge graph verb can perform operations on the nodes of the graph.
+
+    ### Usage
+    ```yaml
+    nodes:
+        <attribute name>: <operation>
+        ... for each attribute or use the special value "*" for all attributes
+    ```
+
+    ## Edge Operations
+    The merge graph verb can perform operations on the edges of the graph.
+
+    ### Usage
+    ```yaml
+    edges:
+        <attribute name>: <operation>
+        ... for each attribute or use the special value "*" for all attributes
+    ```
+
+    ## Operations
+    The merge graph verb can perform operations on the nodes and edges of the graph. The following operations are available:
+
+    - __replace__: This operation replaces the attribute with the last value seen.
+    - __skip__: This operation skips the attribute, and just uses the first value seen.
+    - __concat__: This operation concatenates the attribute with the last value seen.
+    - __sum__: This operation sums the attribute with the last value seen.
+    - __max__: This operation takes the max of the attribute with the last value seen.
+    - __min__: This operation takes the min of the attribute with the last value seen.
+    - __average__: This operation takes the mean of the attribute with the last value seen.
+    - __multiply__: This operation multiplies the attribute with the last value seen.
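+
+    ### Example
+    A minimal configuration mirroring the built-in defaults (illustrative):
+    ```yaml
+    nodes:
+        "*": replace
+    edges:
+        "*": replace
+        weight: sum
+    ```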
+ """ + input_df = input.get_input() + output = pd.DataFrame() + + node_ops = { + attrib: _get_detailed_attribute_merge_operation(value) + for attrib, value in nodes.items() + } + edge_ops = { + attrib: _get_detailed_attribute_merge_operation(value) + for attrib, value in edges.items() + } + + mega_graph = nx.Graph() + num_total = len(input_df) + for graphml in progress_iterable(input_df[column], callbacks.progress, num_total): + graph = load_graph(cast(str | nx.Graph, graphml)) + merge_nodes(mega_graph, graph, node_ops) + merge_edges(mega_graph, graph, edge_ops) + + output[to] = ["\n".join(nx.generate_graphml(mega_graph))] + + return TableContainer(table=output) + + +def merge_nodes( + target: nx.Graph, + subgraph: nx.Graph, + node_ops: dict[str, DetailedAttributeMergeOperation], +): + """Merge nodes from subgraph into target using the operations defined in node_ops.""" + for node in subgraph.nodes: + if node not in target.nodes: + target.add_node(node, **(subgraph.nodes[node] or {})) + else: + merge_attributes(target.nodes[node], subgraph.nodes[node], node_ops) + + +def merge_edges( + target_graph: nx.Graph, + subgraph: nx.Graph, + edge_ops: dict[str, DetailedAttributeMergeOperation], +): + """Merge edges from subgraph into target using the operations defined in edge_ops.""" + for source, target, edge_data in subgraph.edges(data=True): # type: ignore + if not target_graph.has_edge(source, target): + target_graph.add_edge(source, target, **(edge_data or {})) + else: + merge_attributes(target_graph.edges[(source, target)], edge_data, edge_ops) + + +def merge_attributes( + target_item: dict[str, Any] | None, + source_item: dict[str, Any] | None, + ops: dict[str, DetailedAttributeMergeOperation], +): + """Merge attributes from source_item into target_item using the operations defined in ops.""" + source_item = source_item or {} + target_item = target_item or {} + for op_attrib, op in ops.items(): + if op_attrib == "*": + for attrib in source_item: + # If there is a specific handler for this attribute, use it + # i.e. 
* provides a default, but you can override it + if attrib not in ops: + apply_merge_operation(target_item, source_item, attrib, op) + else: + if op_attrib in source_item or op_attrib in target_item: + apply_merge_operation(target_item, source_item, op_attrib, op) + + +def apply_merge_operation( + target_item: dict[str, Any] | None, + source_item: dict[str, Any] | None, + attrib: str, + op: DetailedAttributeMergeOperation, +): + """Apply the merge operation to the attribute.""" + source_item = source_item or {} + target_item = target_item or {} + + if ( + op.operation == BasicMergeOperation.Replace + or op.operation == StringOperation.Replace + ): + target_item[attrib] = source_item.get(attrib, None) or "" + elif ( + op.operation == BasicMergeOperation.Skip or op.operation == StringOperation.Skip + ): + target_item[attrib] = target_item.get(attrib, None) or "" + elif op.operation == StringOperation.Concat: + separator = op.separator or DEFAULT_CONCAT_SEPARATOR + target_attrib = target_item.get(attrib, "") or "" + source_attrib = source_item.get(attrib, "") or "" + target_item[attrib] = f"{target_attrib}{separator}{source_attrib}" + if op.distinct: + # TODO: Slow + target_item[attrib] = separator.join( + sorted(set(target_item[attrib].split(separator))) + ) + + # We're assuming that the attribute is numeric + elif op.operation == NumericOperation.Sum: + target_item[attrib] = (target_item.get(attrib, 0) or 0) + ( + source_item.get(attrib, 0) or 0 + ) + elif op.operation == NumericOperation.Average: + target_item[attrib] = ( + (target_item.get(attrib, 0) or 0) + (source_item.get(attrib, 0) or 0) + ) / 2 + elif op.operation == NumericOperation.Max: + target_item[attrib] = max( + (target_item.get(attrib, 0) or 0), (source_item.get(attrib, 0) or 0) + ) + elif op.operation == NumericOperation.Min: + target_item[attrib] = min( + (target_item.get(attrib, 0) or 0), (source_item.get(attrib, 0) or 0) + ) + elif op.operation == NumericOperation.Multiply: + target_item[attrib] = (target_item.get(attrib, 1) or 1) * ( + source_item.get(attrib, 1) or 1 + ) + else: + msg = f"Invalid operation {op.operation}" + raise ValueError(msg) + + +def _get_detailed_attribute_merge_operation( + value: str | dict[str, Any], +) -> DetailedAttributeMergeOperation: + """Normalize the AttributeMergeOperation into a DetailedAttributeMergeOperation.""" + if isinstance(value, str): + return DetailedAttributeMergeOperation(operation=value) + return DetailedAttributeMergeOperation(**value) diff --git a/func-app/graphrag/index/verbs/graph/merge/typing.py b/func-app/graphrag/index/verbs/graph/merge/typing.py new file mode 100644 index 0000000000..0e534f516c --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/merge/typing.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A module containing 'BasicMergeOperation', 'StringOperation', 'NumericOperation' and 'DetailedAttributeMergeOperation' models."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+
+class BasicMergeOperation(str, Enum):
+    """Basic Merge Operation class definition."""
+
+    Replace = "replace"
+    Skip = "skip"
+
+
+class StringOperation(str, Enum):
+    """String Operation class definition."""
+
+    Concat = "concat"
+    Replace = "replace"
+    Skip = "skip"
+
+
+class NumericOperation(str, Enum):
+    """Numeric Operation class definition."""
+
+    Sum = "sum"
+    Average = "average"
+    Max = "max"
+    Min = "min"
+    Multiply = "multiply"
+    Replace = "replace"
+    Skip = "skip"
+
+
+@dataclass
+class DetailedAttributeMergeOperation:
+    """Detailed attribute merge operation class definition."""
+
+    operation: str  # StringOperation | NumericOperation
+
+    # concat
+    separator: str | None = None
+    delimiter: str | None = None
+    distinct: bool = False
+
+
+AttributeMergeOperation = str | DetailedAttributeMergeOperation
diff --git a/func-app/graphrag/index/verbs/graph/report/__init__.py b/func-app/graphrag/index/verbs/graph/report/__init__.py
new file mode 100644
index 0000000000..e47d9ccef5
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph report package root."""
+
+from .create_community_reports import (
+    CreateCommunityReportsStrategyType,
+    create_community_reports,
+)
+from .prepare_community_reports import prepare_community_reports
+from .prepare_community_reports_claims import prepare_community_reports_claims
+from .prepare_community_reports_edges import prepare_community_reports_edges
+from .prepare_community_reports_nodes import prepare_community_reports_nodes
+from .restore_community_hierarchy import restore_community_hierarchy
+
+__all__ = [
+    "CreateCommunityReportsStrategyType",
+    "create_community_reports",
+    "prepare_community_reports",
+    "prepare_community_reports_claims",
+    "prepare_community_reports_edges",
+    "prepare_community_reports_nodes",
+    "restore_community_hierarchy",
+]
diff --git a/func-app/graphrag/index/verbs/graph/report/create_community_reports.py b/func-app/graphrag/index/verbs/graph/report/create_community_reports.py
new file mode 100644
index 0000000000..c67d5107e8
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/create_community_reports.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""A module containing create_community_reports and load_strategy methods definition.""" + +import logging +from enum import Enum +from typing import cast + +import pandas as pd +from datashaper import ( + AsyncType, + NoopVerbCallbacks, + TableContainer, + VerbCallbacks, + VerbInput, + derive_from_rows, + progress_ticker, + verb, +) + +import graphrag.config.defaults as defaults +import graphrag.index.graph.extractors.community_reports.schemas as schemas +from graphrag.index.cache import PipelineCache +from graphrag.index.graph.extractors.community_reports import ( + get_levels, + prep_community_report_context, +) +from graphrag.index.utils.ds_util import get_required_input_table + +from .strategies.typing import CommunityReport, CommunityReportsStrategy + +log = logging.getLogger(__name__) + + +class CreateCommunityReportsStrategyType(str, Enum): + """CreateCommunityReportsStrategyType class definition.""" + + graph_intelligence = "graph_intelligence" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +@verb(name="create_community_reports") +async def create_community_reports( + input: VerbInput, + callbacks: VerbCallbacks, + cache: PipelineCache, + strategy: dict, + async_mode: AsyncType = AsyncType.AsyncIO, + num_threads: int = 4, + **_kwargs, +) -> TableContainer: + """Generate entities for each row, and optionally a graph of those entities.""" + log.debug("create_community_reports strategy=%s", strategy) + local_contexts = cast(pd.DataFrame, input.get_input()) + nodes_ctr = get_required_input_table(input, "nodes") + nodes = cast(pd.DataFrame, nodes_ctr.table) + community_hierarchy_ctr = get_required_input_table(input, "community_hierarchy") + community_hierarchy = cast(pd.DataFrame, community_hierarchy_ctr.table) + + levels = get_levels(nodes) + reports: list[CommunityReport | None] = [] + tick = progress_ticker(callbacks.progress, len(local_contexts)) + runner = load_strategy(strategy["type"]) + + for level in levels: + level_contexts = prep_community_report_context( + pd.DataFrame(reports), + local_context_df=local_contexts, + community_hierarchy_df=community_hierarchy, + level=level, + max_tokens=strategy.get( + "max_input_tokens", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH + ), + ) + + async def run_generate(record): + result = await _generate_report( + runner, + community_id=record[schemas.NODE_COMMUNITY], + community_level=record[schemas.COMMUNITY_LEVEL], + community_context=record[schemas.CONTEXT_STRING], + cache=cache, + callbacks=callbacks, + strategy=strategy, + ) + tick() + return result + + local_reports = await derive_from_rows( + level_contexts, + run_generate, + callbacks=NoopVerbCallbacks(), + num_threads=num_threads, + scheduling_type=async_mode, + ) + reports.extend([lr for lr in local_reports if lr is not None]) + + return TableContainer(table=pd.DataFrame(reports)) + + +async def _generate_report( + runner: CommunityReportsStrategy, + cache: PipelineCache, + callbacks: VerbCallbacks, + strategy: dict, + community_id: int | str, + community_level: int, + community_context: str, +) -> CommunityReport | None: + """Generate a report for a single community.""" + return await runner( + community_id, community_context, community_level, callbacks, cache, strategy + ) + + +def load_strategy( + strategy: CreateCommunityReportsStrategyType, +) -> CommunityReportsStrategy: + """Load strategy method definition.""" + match strategy: + case CreateCommunityReportsStrategyType.graph_intelligence: + from 
.strategies.graph_intelligence import run + + return run + case _: + msg = f"Unknown strategy: {strategy}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/graph/report/prepare_community_reports.py b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports.py new file mode 100644 index 0000000000..3c9ebd451a --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing create_community_reports and load_strategy methods definition.""" + +import logging +from typing import cast + +import pandas as pd +from datashaper import ( + TableContainer, + VerbCallbacks, + VerbInput, + progress_iterable, + verb, +) + +import graphrag.index.graph.extractors.community_reports.schemas as schemas +from graphrag.index.graph.extractors.community_reports import ( + filter_claims_to_nodes, + filter_edges_to_nodes, + filter_nodes_to_level, + get_levels, + set_context_exceeds_flag, + set_context_size, + sort_context, +) +from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table + +log = logging.getLogger(__name__) + + +@verb(name="prepare_community_reports") +def prepare_community_reports( + input: VerbInput, + callbacks: VerbCallbacks, + max_tokens: int = 16_000, + **_kwargs, +) -> TableContainer: + """Generate entities for each row, and optionally a graph of those entities.""" + # Prepare Community Reports + node_df = cast(pd.DataFrame, get_required_input_table(input, "nodes").table) + edge_df = cast(pd.DataFrame, get_required_input_table(input, "edges").table) + claim_df = get_named_input_table(input, "claims") + if claim_df is not None: + claim_df = cast(pd.DataFrame, claim_df.table) + + levels = get_levels(node_df, schemas.NODE_LEVEL) + dfs = [] + + for level in progress_iterable(levels, callbacks.progress, len(levels)): + communities_at_level_df = _prepare_reports_at_level( + node_df, edge_df, claim_df, level, max_tokens + ) + dfs.append(communities_at_level_df) + + # build initial local context for all communities + return TableContainer(table=pd.concat(dfs)) + + +def _prepare_reports_at_level( + node_df: pd.DataFrame, + edge_df: pd.DataFrame, + claim_df: pd.DataFrame | None, + level: int, + max_tokens: int = 16_000, + community_id_column: str = schemas.COMMUNITY_ID, + node_id_column: str = schemas.NODE_ID, + node_name_column: str = schemas.NODE_NAME, + node_details_column: str = schemas.NODE_DETAILS, + node_level_column: str = schemas.NODE_LEVEL, + node_degree_column: str = schemas.NODE_DEGREE, + node_community_column: str = schemas.NODE_COMMUNITY, + edge_id_column: str = schemas.EDGE_ID, + edge_source_column: str = schemas.EDGE_SOURCE, + edge_target_column: str = schemas.EDGE_TARGET, + edge_degree_column: str = schemas.EDGE_DEGREE, + edge_details_column: str = schemas.EDGE_DETAILS, + claim_id_column: str = schemas.CLAIM_ID, + claim_subject_column: str = schemas.CLAIM_SUBJECT, + claim_details_column: str = schemas.CLAIM_DETAILS, +): + def get_edge_details(node_df: pd.DataFrame, edge_df: pd.DataFrame, name_col: str): + return node_df.merge( + cast( + pd.DataFrame, + edge_df[[name_col, schemas.EDGE_DETAILS]], + ).rename(columns={name_col: schemas.NODE_NAME}), + on=schemas.NODE_NAME, + how="left", + ) + + level_node_df = filter_nodes_to_level(node_df, level) + log.info("Number of nodes at level=%s => %s", level, len(level_node_df)) + nodes = level_node_df[node_name_column].tolist() + + # Filter 
edges & claims to those containing the target nodes + level_edge_df = filter_edges_to_nodes(edge_df, nodes) + level_claim_df = ( + filter_claims_to_nodes(claim_df, nodes) if claim_df is not None else None + ) + + # concat all edge details per node + merged_node_df = pd.concat( + [ + get_edge_details(level_node_df, level_edge_df, edge_source_column), + get_edge_details(level_node_df, level_edge_df, edge_target_column), + ], + axis=0, + ) + merged_node_df = ( + merged_node_df.groupby([ + node_name_column, + node_community_column, + node_degree_column, + node_level_column, + ]) + .agg({node_details_column: "first", edge_details_column: list}) + .reset_index() + ) + + # concat claim details per node + if level_claim_df is not None: + merged_node_df = merged_node_df.merge( + cast( + pd.DataFrame, + level_claim_df[[claim_subject_column, claim_details_column]], + ).rename(columns={claim_subject_column: node_name_column}), + on=node_name_column, + how="left", + ) + merged_node_df = ( + merged_node_df.groupby([ + node_name_column, + node_community_column, + node_level_column, + node_degree_column, + ]) + .agg({ + node_details_column: "first", + edge_details_column: "first", + **({claim_details_column: list} if level_claim_df is not None else {}), + }) + .reset_index() + ) + + # concat all node details, including name, degree, node_details, edge_details, and claim_details + merged_node_df[schemas.ALL_CONTEXT] = merged_node_df.apply( + lambda x: { + node_name_column: x[node_name_column], + node_degree_column: x[node_degree_column], + node_details_column: x[node_details_column], + edge_details_column: x[edge_details_column], + claim_details_column: x[claim_details_column] + if level_claim_df is not None + else [], + }, + axis=1, + ) + + # group all node details by community + community_df = ( + merged_node_df.groupby(node_community_column) + .agg({schemas.ALL_CONTEXT: list}) + .reset_index() + ) + community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply( + lambda x: sort_context( + x, + node_id_column=node_id_column, + node_name_column=node_name_column, + node_details_column=node_details_column, + edge_id_column=edge_id_column, + edge_details_column=edge_details_column, + edge_degree_column=edge_degree_column, + edge_source_column=edge_source_column, + edge_target_column=edge_target_column, + claim_id_column=claim_id_column, + claim_details_column=claim_details_column, + community_id_column=community_id_column, + ) + ) + set_context_size(community_df) + set_context_exceeds_flag(community_df, max_tokens) + + community_df[schemas.COMMUNITY_LEVEL] = level + return community_df diff --git a/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_claims.py b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_claims.py new file mode 100644 index 0000000000..aa9a790772 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_claims.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition.""" + +from typing import cast + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + +from graphrag.index.graph.extractors.community_reports.schemas import ( + CLAIM_DESCRIPTION, + CLAIM_DETAILS, + CLAIM_ID, + CLAIM_STATUS, + CLAIM_SUBJECT, + CLAIM_TYPE, +) + +_MISSING_DESCRIPTION = "No Description" + + +@verb(name="prepare_community_reports_claims") +def prepare_community_reports_claims( + input: VerbInput, + to: str = CLAIM_DETAILS, + id_column: str = CLAIM_ID, + description_column: str = CLAIM_DESCRIPTION, + subject_column: str = CLAIM_SUBJECT, + type_column: str = CLAIM_TYPE, + status_column: str = CLAIM_STATUS, + **_kwargs, +) -> TableContainer: + """Merge claim details into an object.""" + claim_df: pd.DataFrame = cast(pd.DataFrame, input.get_input()) + claim_df = claim_df.fillna(value={description_column: _MISSING_DESCRIPTION}) + + # merge values of five columns into a map column + claim_df[to] = claim_df.apply( + lambda x: { + id_column: x[id_column], + subject_column: x[subject_column], + type_column: x[type_column], + status_column: x[status_column], + description_column: x[description_column], + }, + axis=1, + ) + + return TableContainer(table=claim_df) diff --git a/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_edges.py b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_edges.py new file mode 100644 index 0000000000..b568aba006 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_edges.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition.""" + +from typing import cast + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + +from graphrag.index.graph.extractors.community_reports.schemas import ( + EDGE_DEGREE, + EDGE_DESCRIPTION, + EDGE_DETAILS, + EDGE_ID, + EDGE_SOURCE, + EDGE_TARGET, +) + +_MISSING_DESCRIPTION = "No Description" + + +@verb(name="prepare_community_reports_edges") +def prepare_community_reports_edges( + input: VerbInput, + to: str = EDGE_DETAILS, + id_column: str = EDGE_ID, + source_column: str = EDGE_SOURCE, + target_column: str = EDGE_TARGET, + description_column: str = EDGE_DESCRIPTION, + degree_column: str = EDGE_DEGREE, + **_kwargs, +) -> TableContainer: + """Merge edge details into an object.""" + edge_df: pd.DataFrame = cast(pd.DataFrame, input.get_input()).fillna( + value={description_column: _MISSING_DESCRIPTION} + ) + edge_df[to] = edge_df.apply( + lambda x: { + id_column: x[id_column], + source_column: x[source_column], + target_column: x[target_column], + description_column: x[description_column], + degree_column: x[degree_column], + }, + axis=1, + ) + return TableContainer(table=edge_df) diff --git a/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_nodes.py b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_nodes.py new file mode 100644 index 0000000000..f159c125ee --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_nodes.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition.""" + +from typing import cast + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + +from graphrag.index.graph.extractors.community_reports.schemas import ( + NODE_DEGREE, + NODE_DESCRIPTION, + NODE_DETAILS, + NODE_ID, + NODE_NAME, +) + +_MISSING_DESCRIPTION = "No Description" + + +@verb(name="prepare_community_reports_nodes") +def prepare_community_reports_nodes( + input: VerbInput, + to: str = NODE_DETAILS, + id_column: str = NODE_ID, + name_column: str = NODE_NAME, + description_column: str = NODE_DESCRIPTION, + degree_column: str = NODE_DEGREE, + **_kwargs, +) -> TableContainer: + """Merge edge details into an object.""" + node_df = cast(pd.DataFrame, input.get_input()) + node_df = node_df.fillna(value={description_column: _MISSING_DESCRIPTION}) + + # merge values of four columns into a map column + node_df[to] = node_df.apply( + lambda x: { + id_column: x[id_column], + name_column: x[name_column], + description_column: x[description_column], + degree_column: x[degree_column], + }, + axis=1, + ) + return TableContainer(table=node_df) diff --git a/func-app/graphrag/index/verbs/graph/report/restore_community_hierarchy.py b/func-app/graphrag/index/verbs/graph/report/restore_community_hierarchy.py new file mode 100644 index 0000000000..437369f0e5 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/restore_community_hierarchy.py @@ -0,0 +1,78 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition.""" + +import logging +from typing import cast + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + +import graphrag.index.graph.extractors.community_reports.schemas as schemas + +log = logging.getLogger(__name__) + + +@verb(name="restore_community_hierarchy") +def restore_community_hierarchy( + input: VerbInput, + name_column: str = schemas.NODE_NAME, + community_column: str = schemas.NODE_COMMUNITY, + level_column: str = schemas.NODE_LEVEL, + **_kwargs, +) -> TableContainer: + """Restore the community hierarchy from the node data.""" + node_df: pd.DataFrame = cast(pd.DataFrame, input.get_input()) + community_df = ( + node_df.groupby([community_column, level_column]) + .agg({name_column: list}) + .reset_index() + ) + community_levels = {} + for _, row in community_df.iterrows(): + level = row[level_column] + name = row[name_column] + community = row[community_column] + + if community_levels.get(level) is None: + community_levels[level] = {} + community_levels[level][community] = name + + # get unique levels, sorted in ascending order + levels = sorted(community_levels.keys()) + + community_hierarchy = [] + + for idx in range(len(levels) - 1): + level = levels[idx] + log.debug("Level: %s", level) + next_level = levels[idx + 1] + current_level_communities = community_levels[level] + next_level_communities = community_levels[next_level] + log.debug( + "Number of communities at level %s: %s", + level, + len(current_level_communities), + ) + + for current_community in current_level_communities: + current_entities = current_level_communities[current_community] + + # loop through next level's communities to find all the subcommunities + entities_found = 0 + for next_level_community in next_level_communities: + 
next_entities = next_level_communities[next_level_community] + if set(next_entities).issubset(set(current_entities)): + community_hierarchy.append({ + community_column: current_community, + schemas.COMMUNITY_LEVEL: level, + schemas.SUB_COMMUNITY: next_level_community, + schemas.SUB_COMMUNITY_SIZE: len(next_entities), + }) + + entities_found += len(next_entities) + if entities_found == len(current_entities): + break + + return TableContainer(table=pd.DataFrame(community_hierarchy)) diff --git a/func-app/graphrag/index/verbs/graph/report/strategies/__init__.py b/func-app/graphrag/index/verbs/graph/report/strategies/__init__.py new file mode 100644 index 0000000000..87d1f9e252 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/strategies/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph report strategies package root.""" diff --git a/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/__init__.py b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/__init__.py new file mode 100644 index 0000000000..7f51d7909b --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph report strategies graph intelligence package root.""" + +from .run_graph_intelligence import run + +__all__ = ["run"] diff --git a/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/defaults.py b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/defaults.py new file mode 100644 index 0000000000..708d48d2b6 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/defaults.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A file containing DEFAULT_CHUNK_SIZE and MOCK_RESPONSES definitions."""
+
+import json
+
+DEFAULT_CHUNK_SIZE = 3000
+MOCK_RESPONSES = [
+    json.dumps({
+        "title": "",
+        "summary": "",
+        "rating": 2,
+        "rating_explanation": "",
+        "findings": [
+            {
+                "summary": "",
+                "explanation": "",
+            },
+            {
+                "summary": "",
+                "explanation": "",
+            },
+        ],
+    })
+]
diff --git a/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/run_graph_intelligence.py b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/run_graph_intelligence.py
new file mode 100644
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/run_graph_intelligence.py
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run, _run_extractor and _parse_rank methods definition."""
+
+import json
+import logging
+import traceback
+
+from datashaper import VerbCallbacks
+
+from graphrag.config.enums import LLMType
+from graphrag.index.cache import PipelineCache
+from graphrag.index.graph.extractors.community_reports import (
+    CommunityReportsExtractor,
+)
+from graphrag.index.llm import load_llm
+from graphrag.index.utils.rate_limiter import RateLimiter
+from graphrag.index.verbs.graph.report.strategies.typing import (
+    CommunityReport,
+    StrategyConfig,
+)
+from graphrag.llm import CompletionLLM
+
+from .defaults import MOCK_RESPONSES
+
+log = logging.getLogger(__name__)
+
+
+async def run(
+    community: str | int,
+    input: str,
+    level: int,
+    reporter: VerbCallbacks,
+    pipeline_cache: PipelineCache,
+    args: StrategyConfig,
+) -> CommunityReport | None:
+    """Run the graph intelligence community report extraction strategy."""
+    llm_config = args.get(
+        "llm", {"type": LLMType.StaticResponse, "responses": MOCK_RESPONSES}
+    )
+    llm_type = llm_config.get("type", LLMType.StaticResponse)
+    llm = load_llm(
+        "community_reporting", llm_type, reporter, pipeline_cache, llm_config
+    )
+    return await _run_extractor(llm, community, input, level, args, reporter)
+
+
+async def _run_extractor(
+    llm: CompletionLLM,
+    community: str | int,
+    input: str,
+    level: int,
+    args: StrategyConfig,
+    reporter: VerbCallbacks,
+) -> CommunityReport | None:
+    # throttle extraction calls to at most one per minute
+    rate_limiter = RateLimiter(rate=1, per=60)
+    extractor = CommunityReportsExtractor(
+        llm,
+        extraction_prompt=args.get("extraction_prompt", None),
+        max_report_length=args.get("max_report_length", None),
+        on_error=lambda e, stack, _data: reporter.error(
+            "Community Report Extraction Error", e, stack
+        ),
+    )
+
+    try:
+        await rate_limiter.acquire()
+        results = await extractor({"input_text": input})
+        report = results.structured_output
+        if report is None or len(report.keys()) == 0:
+            log.warning("No report found for community: %s", community)
+            return None
+
+        return CommunityReport(
+            community=community,
+            full_content=results.output,
+            level=level,
+            rank=_parse_rank(report),
+            title=report.get("title", f"Community Report: {community}"),
+            rank_explanation=report.get("rating_explanation", ""),
+            summary=report.get("summary", ""),
+            findings=report.get("findings", []),
+            full_content_json=json.dumps(report, indent=4),
+        )
+    except Exception as e:
+        log.exception("Error processing community: %s", community)
+        reporter.error("Community Report Extraction Error", e, traceback.format_exc())
+        return None
+
+
+def _parse_rank(report: dict) -> float:
+    rank = report.get("rating", -1)
+    try:
+        return float(rank)
+    except ValueError:
+        log.exception("Error parsing rank: %s defaulting to -1", rank)
+        return -1
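+
+# Note (a sketch, not executed by the pipeline): because the default llm
+# config above falls back to a static-response mock, this strategy can be
+# smoke-tested without any LLM credentials, e.g.
+#
+#   >>> import json
+#   >>> report = json.loads(MOCK_RESPONSES[0])
+#   >>> sorted(report)
+#   ['findings', 'rating', 'rating_explanation', 'summary', 'title']
diff --git a/func-app/graphrag/index/verbs/graph/report/strategies/typing.py b/func-app/graphrag/index/verbs/graph/report/strategies/typing.py
new file mode 100644
index 0000000000..087c724702
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/strategies/typing.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 Microsoft Corporation.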
+# Licensed under the MIT License
+
+"""A module containing 'Finding' and 'CommunityReport' models."""
+
+from collections.abc import Awaitable, Callable
+from typing import Any
+
+from datashaper import VerbCallbacks
+from typing_extensions import TypedDict
+
+from graphrag.index.cache import PipelineCache
+
+ExtractedEntity = dict[str, Any]
+StrategyConfig = dict[str, Any]
+RowContext = dict[str, Any]
+EntityTypes = list[str]
+Claim = dict[str, Any]
+
+
+class Finding(TypedDict):
+    """Finding class definition."""
+
+    summary: str
+    explanation: str
+
+
+class CommunityReport(TypedDict):
+    """Community report class definition."""
+
+    community: str | int
+    title: str
+    summary: str
+    full_content: str
+    full_content_json: str
+    rank: float
+    level: int
+    rank_explanation: str
+    findings: list[Finding]
+
+
+CommunityReportsStrategy = Callable[
+    [
+        str | int,
+        str,
+        int,
+        VerbCallbacks,
+        PipelineCache,
+        StrategyConfig,
+    ],
+    Awaitable[CommunityReport | None],
+]
diff --git a/func-app/graphrag/index/verbs/graph/unpack.py b/func-app/graphrag/index/verbs/graph/unpack.py
new file mode 100644
index 0000000000..ffb7f4b0a2
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/unpack.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing unpack_graph, _run_unpack, _unpack_nodes and _unpack_edges methods definition."""
+
+from typing import Any, cast
+
+import networkx as nx
+import pandas as pd
+from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb
+
+from graphrag.index.utils import load_graph
+
+default_copy = ["level"]
+
+
+@verb(name="unpack_graph")
+def unpack_graph(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    column: str,
+    type: str,  # noqa A002
+    copy: list[str] | None = None,
+    embeddings_column: str = "embeddings",
+    **kwargs,
+) -> TableContainer:
+    """
+    Unpack nodes or edges from a graphml graph, into a list of nodes or edges.
+
+    This verb will create columns for each attribute in a node or edge.
+
+    ## Usage
+    ```yaml
+    verb: unpack_graph
+    args:
+        type: nodes # The type of data to unpack, one of: nodes, edges. nodes will create a node list, edges will create an edge list
+        column: # The name of the column containing the graph, should be a graphml graph
+    ```
+    """
+    if copy is None:
+        copy = default_copy
+    input_df = input.get_input()
+    num_total = len(input_df)
+    result = []
+    copy = [col for col in copy if col in input_df.columns]
+    has_embeddings = embeddings_column in input_df.columns
+
+    for _, row in progress_iterable(input_df.iterrows(), callbacks.progress, num_total):
+        # merge the original row with the unpacked graph item
+        cleaned_row = {col: row[col] for col in copy}
+        embeddings = (
+            cast(dict[str, list[float]], row[embeddings_column])
+            if has_embeddings
+            else {}
+        )
+
+        result.extend([
+            {**cleaned_row, **graph_id}
+            for graph_id in _run_unpack(
+                cast(str | nx.Graph, row[column]),
+                type,
+                embeddings,
+                kwargs,
+            )
+        ])
+
+    output_df = pd.DataFrame(result)
+    return TableContainer(table=output_df)
+
+
+def _run_unpack(
+    graphml_or_graph: str | nx.Graph,
+    unpack_type: str,
+    embeddings: dict[str, list[float]],
+    args: dict[str, Any],
+) -> list[dict[str, Any]]:
+    graph = load_graph(graphml_or_graph)
+    if unpack_type == "nodes":
+        return _unpack_nodes(graph, embeddings, args)
+    if unpack_type == "edges":
+        return _unpack_edges(graph, args)
+    msg = f"Unknown type {unpack_type}"
+    raise ValueError(msg)
+
+
+def _unpack_nodes(
+    graph: nx.Graph, embeddings: dict[str, list[float]], _args: dict[str, Any]
+) -> list[dict[str, Any]]:
+    return [
+        {
+            "label": label,
+            **(node_data or {}),
+            "graph_embedding": embeddings.get(label),
+        }
+        for label, node_data in graph.nodes(data=True)  # type: ignore
+    ]
+
+
+def _unpack_edges(graph: nx.Graph, _args: dict[str, Any]) -> list[dict[str, Any]]:
+    return [
+        {
+            "source": source_id,
+            "target": target_id,
+            **(edge_data or {}),
+        }
+        for source_id, target_id, edge_data in graph.edges(data=True)  # type: ignore
+    ]
diff --git a/func-app/graphrag/index/verbs/overrides/__init__.py b/func-app/graphrag/index/verbs/overrides/__init__.py
new file mode 100644
index 0000000000..24b82c1f3e
--- /dev/null
+++ b/func-app/graphrag/index/verbs/overrides/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine overrides package root."""
+
+from .aggregate import aggregate
+from .concat import concat
+from .merge import merge
+
+__all__ = ["aggregate", "concat", "merge"]
diff --git a/func-app/graphrag/index/verbs/overrides/aggregate.py b/func-app/graphrag/index/verbs/overrides/aggregate.py
new file mode 100644
index 0000000000..df2137046b
--- /dev/null
+++ b/func-app/graphrag/index/verbs/overrides/aggregate.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing 'Aggregation' model."""
+
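+# A configuration sketch for this verb (the column names are illustrative;
+# the operation values come from datashaper's FieldAggregateOperation):
+#
+#   verb: aggregate_override
+#   args:
+#     groupby: ["community"]
+#     aggregations:
+#       - column: "text_unit_ids"
+#         operation: "array_agg"
+#         to: "all_text_unit_ids"
+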
+from dataclasses import dataclass
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    FieldAggregateOperation,
+    Progress,
+    TableContainer,
+    VerbCallbacks,
+    VerbInput,
+    aggregate_operation_mapping,
+    verb,
+)
+
+ARRAY_AGGREGATIONS = [
+    FieldAggregateOperation.ArrayAgg,
+    FieldAggregateOperation.ArrayAggDistinct,
+]
+
+
+# TODO: This thing is kinda gross
+# Also, it diverges from the original aggregate verb, since it doesn't support the same syntax
+@verb(name="aggregate_override")
+def aggregate(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    aggregations: list[dict[str, Any]],
+    groupby: list[str] | None = None,
+    **_kwargs: dict,
+) -> TableContainer:
+    """Aggregate method definition."""
+    aggregations_to_apply = _load_aggregations(aggregations)
+    df_aggregations = {
+        agg.column: _get_pandas_agg_operation(agg)
+        for agg in aggregations_to_apply.values()
+    }
+    input_table = input.get_input()
+    callbacks.progress(Progress(percent=0))
+
+    if groupby is None:
+        output_grouped = input_table.groupby(lambda _x: True)
+    else:
+        output_grouped = input_table.groupby(groupby, sort=False)
+    output = cast(pd.DataFrame, output_grouped.agg(df_aggregations))
+    output.rename(
+        columns={agg.column: agg.to for agg in aggregations_to_apply.values()},
+        inplace=True,
+    )
+    output.columns = [agg.to for agg in aggregations_to_apply.values()]
+
+    callbacks.progress(Progress(percent=1))
+
+    return TableContainer(table=output.reset_index())
+
+
+@dataclass
+class Aggregation:
+    """Aggregation class method definition."""
+
+    column: str | None
+    operation: str
+    to: str
+
+    # Only useful for the concat operation
+    separator: str | None = None
+
+
+def _get_pandas_agg_operation(agg: Aggregation) -> Any:
+    # TODO: Merge into datashaper
+    if agg.operation == "string_concat":
+        return (agg.separator or ",").join
+    return aggregate_operation_mapping[FieldAggregateOperation(agg.operation)]
+
+
+def _load_aggregations(
+    aggregations: list[dict[str, Any]],
+) -> dict[str, Aggregation]:
+    return {
+        aggregation["column"]: Aggregation(
+            aggregation["column"], aggregation["operation"], aggregation["to"]
+        )
+        for aggregation in aggregations
+    }
diff --git a/func-app/graphrag/index/verbs/overrides/concat.py b/func-app/graphrag/index/verbs/overrides/concat.py
new file mode 100644
index 0000000000..7a0f0e2c32
--- /dev/null
+++ b/func-app/graphrag/index/verbs/overrides/concat.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing concat method definition."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+
+@verb(name="concat_override")
+def concat(
+    input: VerbInput,
+    columnwise: bool = False,
+    **_kwargs: dict,
+) -> TableContainer:
+    """Concat method definition."""
+    input_table = cast(pd.DataFrame, input.get_input())
+    others = cast(list[pd.DataFrame], input.get_others())
+    if columnwise:
+        output = pd.concat([input_table, *others], axis=1)
+    else:
+        output = pd.concat([input_table, *others], ignore_index=True)
+    return TableContainer(table=output)
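+
+# concat_override is a thin wrapper over pd.concat; a quick sketch of the two
+# modes (the frames are illustrative):
+#
+#   >>> a, b = pd.DataFrame({"x": [1]}), pd.DataFrame({"y": [2]})
+#   >>> pd.concat([a, b], axis=1)              # columnwise=True
+#   >>> pd.concat([a, b], ignore_index=True)   # columnwise=False (row-wise)
diff --git a/func-app/graphrag/index/verbs/overrides/merge.py b/func-app/graphrag/index/verbs/overrides/merge.py
new file mode 100644
index 0000000000..64684c9828
--- /dev/null
+++ b/func-app/graphrag/index/verbs/overrides/merge.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2024 Microsoft Corporation.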
+# Licensed under the MIT License
+
+"""A module containing merge and _merge_json methods definition."""
+
+import logging
+from enum import Enum
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, VerbResult, verb
+from datashaper.engine.verbs.merge import merge as ds_merge
+
+log = logging.getLogger(__name__)
+
+
+class MergeStrategyType(str, Enum):
+    """MergeStrategy class definition."""
+
+    json = "json"
+    datashaper = "datashaper"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+# TODO: This thing is kinda gross
+# Also, it diverges from the original merge verb, since it doesn't support the same syntax
+@verb(name="merge_override")
+def merge(
+    input: VerbInput,
+    to: str,
+    columns: list[str],
+    strategy: MergeStrategyType = MergeStrategyType.datashaper,
+    delimiter: str = "",
+    preserveSource: bool = False,  # noqa N806
+    unhot: bool = False,
+    prefix: str = "",
+    **_kwargs: dict,
+) -> TableContainer | VerbResult:
+    """Merge method definition."""
+    output: pd.DataFrame
+    match strategy:
+        case MergeStrategyType.json:
+            output = _merge_json(input, to, columns)
+            filtered_list: list[str] = []
+
+            for col in output.columns:
+                try:
+                    columns.index(col)
+                except ValueError:
+                    log.exception("Column %s not found in input columns", col)
+                    filtered_list.append(col)
+
+            if not preserveSource:
+                output = cast(Any, output[filtered_list])
+            return TableContainer(table=output.reset_index())
+        case _:
+            return ds_merge(
+                input, to, columns, strategy, delimiter, preserveSource, unhot, prefix
+            )
+
+
+def _merge_json(
+    input: VerbInput,
+    to: str,
+    columns: list[str],
+) -> pd.DataFrame:
+    input_table = cast(pd.DataFrame, input.get_input())
+    output = input_table
+    output[to] = output[columns].apply(
+        lambda row: ({**row}),
+        axis=1,
+    )
+    return output
diff --git a/func-app/graphrag/index/verbs/snapshot.py b/func-app/graphrag/index/verbs/snapshot.py
new file mode 100644
index 0000000000..b781478532
--- /dev/null
+++ b/func-app/graphrag/index/verbs/snapshot.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing snapshot method definition."""
+
+from datashaper import TableContainer, VerbInput, verb
+
+from graphrag.common.storage import PipelineStorage
+
+
+@verb(name="snapshot")
+async def snapshot(
+    input: VerbInput,
+    name: str,
+    formats: list[str],
+    storage: PipelineStorage,
+    **_kwargs: dict,
+) -> TableContainer:
+    """Take an entire snapshot of the tabular data."""
+    data = input.get_input()
+
+    for fmt in formats:
+        if fmt == "parquet":
+            await storage.set(name + ".parquet", data.to_parquet())
+        elif fmt == "json":
+            await storage.set(
+                name + ".json", data.to_json(orient="records", lines=True)
+            )
+
+    return TableContainer(table=data)
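+
+# A configuration sketch for this verb (the name and format list shown are
+# illustrative, not taken from a shipped workflow):
+#
+#   verb: snapshot
+#   args:
+#     name: entity_graph
+#     formats: ["parquet", "json"]
diff --git a/func-app/graphrag/index/verbs/snapshot_rows.py b/func-app/graphrag/index/verbs/snapshot_rows.py
new file mode 100644
index 0000000000..0b0ca1c3b6
--- /dev/null
+++ b/func-app/graphrag/index/verbs/snapshot_rows.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2024 Microsoft Corporation.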
+# Licensed under the MIT License
+
+"""A module containing 'FormatSpecifier' model."""
+
+import json
+from dataclasses import dataclass
+from typing import Any
+
+from datashaper import TableContainer, VerbInput, verb
+
+from graphrag.common.storage import PipelineStorage
+
+
+@dataclass
+class FormatSpecifier:
+    """Format specifier class definition."""
+
+    format: str
+    extension: str
+
+
+@verb(name="snapshot_rows")
+async def snapshot_rows(
+    input: VerbInput,
+    column: str | None,
+    base_name: str,
+    storage: PipelineStorage,
+    formats: list[str | dict[str, Any]],
+    row_name_column: str | None = None,
+    **_kwargs: dict,
+) -> TableContainer:
+    """Take a by-row snapshot of the tabular data."""
+    data = input.get_input()
+    parsed_formats = _parse_formats(formats)
+    num_rows = len(data)
+
+    def get_row_name(row: Any, row_idx: Any):
+        if row_name_column is None:
+            if num_rows == 1:
+                return base_name
+            return f"{base_name}.{row_idx}"
+        return f"{base_name}.{row[row_name_column]}"
+
+    for row_idx, row in data.iterrows():
+        for fmt in parsed_formats:
+            row_name = get_row_name(row, row_idx)
+            extension = fmt.extension
+            if fmt.format == "json":
+                await storage.set(
+                    f"{row_name}.{extension}",
+                    json.dumps(row[column])
+                    if column is not None
+                    else json.dumps(row.to_dict()),
+                )
+            elif fmt.format == "text":
+                if column is None:
+                    msg = "column must be specified for text format"
+                    raise ValueError(msg)
+                await storage.set(f"{row_name}.{extension}", str(row[column]))
+
+    return TableContainer(table=data)
+
+
+def _parse_formats(formats: list[str | dict[str, Any]]) -> list[FormatSpecifier]:
+    """Parse the formats into a list of FormatSpecifiers."""
+    return [
+        FormatSpecifier(**fmt)
+        if isinstance(fmt, dict)
+        else FormatSpecifier(format=fmt, extension=_get_format_extension(fmt))
+        for fmt in formats
+    ]
+
+
+def _get_format_extension(fmt: str) -> str:
+    """Get the file extension for a given format."""
+    if fmt == "json":
+        return "json"
+    if fmt == "text":
+        return "txt"
+    if fmt == "parquet":
+        return "parquet"
+    if fmt == "csv":
+        return "csv"
+    msg = f"Unknown format: {fmt}"
+    raise ValueError(msg)
diff --git a/func-app/graphrag/index/verbs/spread_json.py b/func-app/graphrag/index/verbs/spread_json.py
new file mode 100644
index 0000000000..38656e12a4
--- /dev/null
+++ b/func-app/graphrag/index/verbs/spread_json.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing spread_json method definition."""
+
+import logging
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+from graphrag.index.utils import is_null
+
+# TODO: Check if this is already a thing
+DEFAULT_COPY = ["level"]
+
+
+@verb(name="spread_json")
+def spread_json(
+    input: VerbInput,
+    column: str,
+    copy: list[str] | None = None,
+    **_kwargs: dict,
+) -> TableContainer:
+    """
+    Spread a column containing a JSON object into multiple columns.
+
+    id|json|b
+    ---------
+    1|{"x":5,"y":6}|b
+
+    is converted to
+
+    id|x|y|b
+    --------
+    1|5|6|b
+    """
+    if copy is None:
+        copy = DEFAULT_COPY
+    data = input.get_input()
+
+    results = []
+    for _, row in data.iterrows():
+        try:
+            cleaned_row = {col: row[col] for col in copy}
+            rest_row = row[column] if row[column] is not None else {}
+
+            if is_null(rest_row):
+                rest_row = {}
+
+            results.append({**cleaned_row, **rest_row})  # type: ignore
+        except Exception:
+            logging.exception("Error spreading row: %s", row)
+            raise
+    data = pd.DataFrame(results, index=data.index)
+
+    return TableContainer(table=data)
diff --git a/func-app/graphrag/index/verbs/text/__init__.py b/func-app/graphrag/index/verbs/text/__init__.py
new file mode 100644
index 0000000000..032f45e1b1
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine text package root."""
+
+from .chunk.text_chunk import chunk
+from .embed import text_embed
+from .replace import replace
+from .split import text_split
+from .translate import text_translate
+
+__all__ = [
+    "chunk",
+    "replace",
+    "text_embed",
+    "text_split",
+    "text_translate",
+]
diff --git a/func-app/graphrag/index/verbs/text/chunk/__init__.py b/func-app/graphrag/index/verbs/text/chunk/__init__.py
new file mode 100644
index 0000000000..4e2a7729c5
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/chunk/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine text chunk package root."""
+
+from .text_chunk import ChunkStrategy, ChunkStrategyType, chunk
+
+__all__ = ["ChunkStrategy", "ChunkStrategyType", "chunk"]
diff --git a/func-app/graphrag/index/verbs/text/chunk/strategies/__init__.py b/func-app/graphrag/index/verbs/text/chunk/strategies/__init__.py
new file mode 100644
index 0000000000..0f15fcb2d5
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/chunk/strategies/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine text chunk strategies package root."""
diff --git a/func-app/graphrag/index/verbs/text/chunk/strategies/sentence.py b/func-app/graphrag/index/verbs/text/chunk/strategies/sentence.py
new file mode 100644
index 0000000000..687def1d90
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/chunk/strategies/sentence.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run method definition."""
+
+from collections.abc import Iterable
+from typing import Any
+
+import nltk
+from datashaper import ProgressTicker
+
+from .typing import TextChunk
+
+
+def run(
+    input: list[str], _args: dict[str, Any], tick: ProgressTicker
+) -> Iterable[TextChunk]:
+    """Chunks text into multiple parts. A pipeline verb."""
+    for doc_idx, text in enumerate(input):
+        sentences = nltk.sent_tokenize(text)
+        for sentence in sentences:
+            yield TextChunk(
+                text_chunk=sentence,
+                source_doc_indices=[doc_idx],
+            )
+        tick(1)
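+
+# A quick sketch of what this strategy yields (assumes nltk's punkt data has
+# been installed, which the chunk verb's bootstrap step takes care of):
+#
+#   >>> import nltk
+#   >>> nltk.sent_tokenize("One sentence. Two sentences.")
+#   ['One sentence.', 'Two sentences.']
diff --git a/func-app/graphrag/index/verbs/text/chunk/strategies/tokens.py b/func-app/graphrag/index/verbs/text/chunk/strategies/tokens.py
new file mode 100644
index 0000000000..6426c783e1
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/chunk/strategies/tokens.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2024 Microsoft Corporation.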
+# Licensed under the MIT License + +"""A module containing run and split_text_on_tokens methods definition.""" + +from collections.abc import Iterable +from typing import Any + +import tiktoken +from datashaper import ProgressTicker + +import graphrag.config.defaults as defs +from graphrag.index.text_splitting import Tokenizer +from graphrag.index.verbs.text.chunk.typing import TextChunk + + +def run( + input: list[str], args: dict[str, Any], tick: ProgressTicker +) -> Iterable[TextChunk]: + """Chunks text into multiple parts. A pipeline verb.""" + tokens_per_chunk = args.get("chunk_size", defs.CHUNK_SIZE) + chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP) + encoding_name = args.get("encoding_name", defs.ENCODING_MODEL) + enc = tiktoken.get_encoding(encoding_name) + + def encode(text: str) -> list[int]: + if not isinstance(text, str): + text = f"{text}" + return enc.encode(text) + + def decode(tokens: list[int]) -> str: + return enc.decode(tokens) + + return split_text_on_tokens( + input, + Tokenizer( + chunk_overlap=chunk_overlap, + tokens_per_chunk=tokens_per_chunk, + encode=encode, + decode=decode, + ), + tick, + ) + + +# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471 +# So we could have better control over the chunking process +def split_text_on_tokens( + texts: list[str], enc: Tokenizer, tick: ProgressTicker +) -> list[TextChunk]: + """Split incoming text and return chunks.""" + result = [] + mapped_ids = [] + + for source_doc_idx, text in enumerate(texts): + encoded = enc.encode(text) + tick(1) + mapped_ids.append((source_doc_idx, encoded)) + + input_ids: list[tuple[int, int]] = [ + (source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids + ] + + start_idx = 0 + cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + while start_idx < len(input_ids): + chunk_text = enc.decode([id for _, id in chunk_ids]) + doc_indices = list({doc_idx for doc_idx, _ in chunk_ids}) + result.append( + TextChunk( + text_chunk=chunk_text, + source_doc_indices=doc_indices, + n_tokens=len(chunk_ids), + ) + ) + start_idx += enc.tokens_per_chunk - enc.chunk_overlap + cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + + return result diff --git a/func-app/graphrag/index/verbs/text/chunk/strategies/typing.py b/func-app/graphrag/index/verbs/text/chunk/strategies/typing.py new file mode 100644 index 0000000000..b4e833c8e3 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/chunk/strategies/typing.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing ChunkStrategy definition.""" + +from collections.abc import Callable, Iterable +from typing import Any + +from datashaper import ProgressTicker + +from graphrag.index.verbs.text.chunk.typing import TextChunk + +# Given a list of document texts, return a list of tuples of (source_doc_indices, text_chunk) + +ChunkStrategy = Callable[ + [list[str], dict[str, Any], ProgressTicker], Iterable[TextChunk] +] diff --git a/func-app/graphrag/index/verbs/text/chunk/text_chunk.py b/func-app/graphrag/index/verbs/text/chunk/text_chunk.py new file mode 100644 index 0000000000..40c5578a0f --- /dev/null +++ b/func-app/graphrag/index/verbs/text/chunk/text_chunk.py @@ -0,0 +1,162 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A module containing _get_num_total, chunk, run_strategy and load_strategy methods definitions."""
+
+from enum import Enum
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    ProgressTicker,
+    TableContainer,
+    VerbCallbacks,
+    VerbInput,
+    progress_ticker,
+    verb,
+)
+
+from .strategies.typing import ChunkStrategy as ChunkStrategy
+from .typing import ChunkInput
+
+
+def _get_num_total(output: pd.DataFrame, column: str) -> int:
+    num_total = 0
+    for row in output[column]:
+        if isinstance(row, str):
+            num_total += 1
+        else:
+            num_total += len(row)
+    return num_total
+
+
+class ChunkStrategyType(str, Enum):
+    """ChunkStrategy class definition."""
+
+    tokens = "tokens"
+    sentence = "sentence"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+@verb(name="chunk")
+def chunk(
+    input: VerbInput,
+    column: str,
+    to: str,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any] | None = None,
+    **_kwargs,
+) -> TableContainer:
+    """
+    Chunk a piece of text into smaller pieces.
+
+    ## Usage
+    ```yaml
+    verb: text_chunk
+    args:
+        column: # The name of the column containing the text to chunk, this can either be a column with text, or a column with a list[tuple[doc_id, str]]
+        to: # The name of the column to output the chunks to
+        strategy: # The strategy to use to chunk the text, see below for more details
+    ```
+
+    ## Strategies
+    The text chunk verb uses a strategy to chunk the text. The strategy is an object which defines the strategy to use. The following strategies are available:
+
+    ### tokens
+    This strategy uses the tiktoken library to chunk a piece of text. The strategy config is as follows:
+
+    > Note: In the future, this will likely be renamed to something more generic, like "openai_tokens".
+
+    ```yaml
+    strategy:
+        type: tokens
+        chunk_size: 1200 # Optional, The chunk size to use, default: 1200
+        chunk_overlap: 100 # Optional, The chunk overlap to use, default: 100
+    ```
+
+    ### sentence
+    This strategy uses the nltk library to chunk a piece of text into sentences.
+    The strategy config is as follows:
+
+    ```yaml
+    strategy:
+        type: sentence
+    ```
+    """
+    if strategy is None:
+        strategy = {}
+    output = cast(pd.DataFrame, input.get_input())
+    strategy_name = strategy.get("type", ChunkStrategyType.tokens)
+    strategy_config = {**strategy}
+    strategy_exec = load_strategy(strategy_name)
+
+    num_total = _get_num_total(output, column)
+    tick = progress_ticker(callbacks.progress, num_total)
+
+    output[to] = output.apply(
+        cast(
+            Any,
+            lambda x: run_strategy(strategy_exec, x[column], strategy_config, tick),
+        ),
+        axis=1,
+    )
+    return TableContainer(table=output)
+
+
+def run_strategy(
+    strategy: ChunkStrategy,
+    input: ChunkInput,
+    strategy_args: dict[str, Any],
+    tick: ProgressTicker,
+) -> list[str | tuple[list[str] | None, str, int]]:
+    """Run strategy method definition."""
+    if isinstance(input, str):
+        return [item.text_chunk for item in strategy([input], {**strategy_args}, tick)]
+
+    # We can work with both just a list of text content
+    # or a list of tuples of (document_id, text content)
+    texts = []
+    for item in input:
+        if isinstance(item, str):
+            texts.append(item)
+        else:
+            texts.append(item[1])
+
+    strategy_results = strategy(texts, {**strategy_args}, tick)
+
+    results = []
+    for strategy_result in strategy_results:
+        doc_indices = strategy_result.source_doc_indices
+        if isinstance(input[doc_indices[0]], str):
+            results.append(strategy_result.text_chunk)
+        else:
+            doc_ids = [input[doc_idx][0] for doc_idx in doc_indices]
+            results.append((
+                doc_ids,
+                strategy_result.text_chunk,
+                strategy_result.n_tokens,
+            ))
+    return results
+
+
+def load_strategy(strategy: ChunkStrategyType) -> ChunkStrategy:
+    """Load strategy method definition."""
+    match strategy:
+        case ChunkStrategyType.tokens:
+            from .strategies.tokens import run as run_tokens
+
+            return run_tokens
+        case ChunkStrategyType.sentence:
+            # NLTK
+            from graphrag.index.bootstrap import bootstrap
+
+            from .strategies.sentence import run as run_sentence
+
+            bootstrap()
+            return run_sentence
+        case _:
+            msg = f"Unknown strategy: {strategy}"
+            raise ValueError(msg)
diff --git a/func-app/graphrag/index/verbs/text/chunk/typing.py b/func-app/graphrag/index/verbs/text/chunk/typing.py
new file mode 100644
index 0000000000..3a42cf68a7
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/chunk/typing.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing 'TextChunk' model."""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class TextChunk:
+    """Text chunk class definition."""
+
+    text_chunk: str
+    source_doc_indices: list[int]
+    n_tokens: int | None = None
+
+
+ChunkInput = str | list[str] | list[tuple[str, str]]
+"""Input to a chunking strategy. Can be a string, a list of strings, or a list of tuples of (id, text)."""
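+
+# The three ChunkInput shapes accepted by run_strategy (values illustrative):
+#
+#   "a single document"                           -> bare chunk strings
+#   ["doc one", "doc two"]                        -> bare chunk strings
+#   [("doc-1", "doc one"), ("doc-2", "doc two")]  -> (doc_ids, text, n_tokens) tuples
diff --git a/func-app/graphrag/index/verbs/text/embed/__init__.py b/func-app/graphrag/index/verbs/text/embed/__init__.py
new file mode 100644
index 0000000000..969bd2aab9
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/embed/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.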
+# Licensed under the MIT License
+
+"""The Indexing Engine text embed package root."""
+
+from .text_embed import TextEmbedStrategyType, text_embed
+
+__all__ = ["TextEmbedStrategyType", "text_embed"]
diff --git a/func-app/graphrag/index/verbs/text/embed/strategies/__init__.py b/func-app/graphrag/index/verbs/text/embed/strategies/__init__.py
new file mode 100644
index 0000000000..8cbe7a580e
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/embed/strategies/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine embed strategies package root."""
diff --git a/func-app/graphrag/index/verbs/text/embed/strategies/mock.py b/func-app/graphrag/index/verbs/text/embed/strategies/mock.py
new file mode 100644
index 0000000000..1be4ab0f9f
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/embed/strategies/mock.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run and _embed_text methods definitions."""
+
+import random
+from collections.abc import Iterable
+from typing import Any
+
+from datashaper import ProgressTicker, VerbCallbacks, progress_ticker
+
+from graphrag.index.cache import PipelineCache
+
+from .typing import TextEmbeddingResult
+
+
+async def run(  # noqa RUF029 async is required for interface
+    input: list[str],
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    _args: dict[str, Any],
+) -> TextEmbeddingResult:
+    """Run the mock text-embedding strategy."""
+    input = input if isinstance(input, Iterable) else [input]
+    ticker = progress_ticker(callbacks.progress, len(input))
+    return TextEmbeddingResult(
+        embeddings=[_embed_text(cache, text, ticker) for text in input]
+    )
+
+
+def _embed_text(_cache: PipelineCache, _text: str, tick: ProgressTicker) -> list[float]:
+    """Embed a single piece of text."""
+    tick(1)
+    return [random.random(), random.random(), random.random()]  # noqa S311
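+
+# A sketch of exercising this mock strategy end to end (assumes the package
+# root is importable and datashaper's NoopVerbCallbacks is available; no
+# model or network access is needed):
+#
+#   >>> import asyncio
+#   >>> from datashaper import NoopVerbCallbacks
+#   >>> from graphrag.index.cache import InMemoryCache
+#   >>> result = asyncio.run(run(["a", "b"], NoopVerbCallbacks(), InMemoryCache(), {}))
+#   >>> [len(v) for v in result.embeddings]
+#   [3, 3]
diff --git a/func-app/graphrag/index/verbs/text/embed/strategies/openai.py b/func-app/graphrag/index/verbs/text/embed/strategies/openai.py
new file mode 100644
index 0000000000..fb443ec83e
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/embed/strategies/openai.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2024 Microsoft Corporation.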
+# Licensed under the MIT License
+
+"""A module containing run method definition."""
+
+import asyncio
+import logging
+from typing import Any
+
+import numpy as np
+from datashaper import ProgressTicker, VerbCallbacks, progress_ticker
+
+import graphrag.config.defaults as defs
+from graphrag.index.cache import PipelineCache
+from graphrag.index.llm import load_llm_embeddings
+from graphrag.index.text_splitting import TokenTextSplitter
+from graphrag.index.utils import is_null
+from graphrag.llm import EmbeddingLLM, OpenAIConfiguration
+
+from .typing import TextEmbeddingResult
+
+log = logging.getLogger(__name__)
+
+
+async def run(
+    input: list[str],
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    args: dict[str, Any],
+) -> TextEmbeddingResult:
+    """Run the OpenAI text-embedding strategy."""
+    if is_null(input):
+        return TextEmbeddingResult(embeddings=None)
+
+    llm_config = args.get("llm", {})
+    batch_size = args.get("batch_size", 16)
+    batch_max_tokens = args.get("batch_max_tokens", 8191)
+    oai_config = OpenAIConfiguration(llm_config)
+    splitter = _get_splitter(oai_config, batch_max_tokens)
+    llm = _get_llm(oai_config, callbacks, cache)
+    semaphore: asyncio.Semaphore = asyncio.Semaphore(args.get("num_threads", 4))
+
+    # Break up the input texts. The sizes here indicate how many snippets are in each input text
+    texts, input_sizes = _prepare_embed_texts(input, splitter)
+    text_batches = _create_text_batches(
+        texts,
+        batch_size,
+        batch_max_tokens,
+        splitter,
+    )
+    log.info(
+        "embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, max_tokens=%d",
+        len(input),
+        len(texts),
+        len(text_batches),
+        batch_size,
+        batch_max_tokens,
+    )
+    ticker = progress_ticker(callbacks.progress, len(text_batches))
+
+    # Embed each chunk of snippets
+    embeddings = await _execute(llm, text_batches, ticker, semaphore)
+    embeddings = _reconstitute_embeddings(embeddings, input_sizes)
+
+    return TextEmbeddingResult(embeddings=embeddings)
+
+
+def _get_splitter(
+    config: OpenAIConfiguration, batch_max_tokens: int
+) -> TokenTextSplitter:
+    return TokenTextSplitter(
+        encoding_name=config.encoding_model or defs.ENCODING_MODEL,
+        chunk_size=batch_max_tokens,
+    )
+
+
+def _get_llm(
+    config: OpenAIConfiguration,
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+) -> EmbeddingLLM:
+    llm_type = config.lookup("type", "Unknown")
+    return load_llm_embeddings(
+        "text_embedding",
+        llm_type,
+        callbacks,
+        cache,
+        config.raw_config,
+    )
+
+
+async def _execute(
+    llm: EmbeddingLLM,
+    chunks: list[list[str]],
+    tick: ProgressTicker,
+    semaphore: asyncio.Semaphore,
+) -> list[list[float]]:
+    async def embed(chunk: list[str]):
+        async with semaphore:
+            chunk_embeddings = await llm(chunk)
+            result = np.array(chunk_embeddings.output)
+            tick(1)
+        return result
+
+    futures = [embed(chunk) for chunk in chunks]
+    results = await asyncio.gather(*futures)
+    # merge results in a single list of lists (reduce the collect dimension)
+    return [item for sublist in results for item in sublist]
+
+
+def _create_text_batches(
+    texts: list[str],
+    max_batch_size: int,
+    max_batch_tokens: int,
+    splitter: TokenTextSplitter,
+) -> list[list[str]]:
+    """Create batches of texts to embed."""
+    # https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
+    # According to this embeddings reference, Azure limits us to 16 concurrent embeddings and 8191 tokens per request
+    result = []
+    current_batch = []
+    current_batch_tokens = 0
+
+    for text in texts:
+        token_count = splitter.num_tokens(text)
+        if (
len(current_batch) >= max_batch_size + or current_batch_tokens + token_count > max_batch_tokens + ): + result.append(current_batch) + current_batch = [] + current_batch_tokens = 0 + + current_batch.append(text) + current_batch_tokens += token_count + + if len(current_batch) > 0: + result.append(current_batch) + + return result + + +def _prepare_embed_texts( + input: list[str], splitter: TokenTextSplitter +) -> tuple[list[str], list[int]]: + sizes: list[int] = [] + snippets: list[str] = [] + + for text in input: + # Split the input text and filter out any empty content + split_texts = splitter.split_text(text) + if split_texts is None: + continue + split_texts = [text for text in split_texts if len(text) > 0] + + sizes.append(len(split_texts)) + snippets.extend(split_texts) + + return snippets, sizes + + +def _reconstitute_embeddings( + raw_embeddings: list[list[float]], sizes: list[int] +) -> list[list[float] | None]: + """Reconstitute the embeddings into the original input texts.""" + embeddings: list[list[float] | None] = [] + cursor = 0 + for size in sizes: + if size == 0: + embeddings.append(None) + elif size == 1: + embedding = raw_embeddings[cursor] + embeddings.append(embedding) + cursor += 1 + else: + chunk = raw_embeddings[cursor : cursor + size] + average = np.average(chunk, axis=0) + normalized = average / np.linalg.norm(average) + embeddings.append(normalized.tolist()) + cursor += size + return embeddings diff --git a/func-app/graphrag/index/verbs/text/embed/strategies/typing.py b/func-app/graphrag/index/verbs/text/embed/strategies/typing.py new file mode 100644 index 0000000000..1b25256497 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/embed/strategies/typing.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'TextEmbeddingResult' model.""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass + +from datashaper import VerbCallbacks + +from graphrag.index.cache import PipelineCache + + +@dataclass +class TextEmbeddingResult: + """Text embedding result class definition.""" + + embeddings: list[list[float] | None] | None + + +TextEmbeddingStrategy = Callable[ + [ + list[str], + VerbCallbacks, + PipelineCache, + dict, + ], + Awaitable[TextEmbeddingResult], +] diff --git a/func-app/graphrag/index/verbs/text/embed/text_embed.py b/func-app/graphrag/index/verbs/text/embed/text_embed.py new file mode 100644 index 0000000000..cd9bbc798d --- /dev/null +++ b/func-app/graphrag/index/verbs/text/embed/text_embed.py @@ -0,0 +1,269 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing text_embed, load_strategy and create_row_from_embedding_data methods definition.""" + +import logging +from enum import Enum +from typing import Any, cast + +import numpy as np +import pandas as pd +from datashaper import TableContainer, VerbCallbacks, VerbInput, verb + +from graphrag.index.cache import PipelineCache +from graphrag.vector_stores import ( + BaseVectorStore, + VectorStoreDocument, + VectorStoreFactory, +) + +from .strategies.typing import TextEmbeddingStrategy + +log = logging.getLogger(__name__) + +# Per Azure OpenAI Limits +# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference +DEFAULT_EMBEDDING_BATCH_SIZE = 500 + + +class TextEmbedStrategyType(str, Enum): + """TextEmbedStrategyType class definition.""" + + openai = "openai" + mock = "mock" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +@verb(name="text_embed") +async def text_embed( + input: VerbInput, + callbacks: VerbCallbacks, + cache: PipelineCache, + column: str, + strategy: dict, + **kwargs, +) -> TableContainer: + """ + Embed a piece of text into a vector space. The verb outputs a new column containing a mapping between doc_id and vector. + + ## Usage + ```yaml + verb: text_embed + args: + column: text # The name of the column containing the text to embed, this can either be a column with text, or a column with a list[tuple[doc_id, str]] + to: embedding # The name of the column to output the embedding to + strategy: # See strategies section below + ``` + + ## Strategies + The text embed verb uses a strategy to embed the text. The strategy is an object which defines the strategy to use. The following strategies are available: + + ### openai + This strategy uses openai to embed a piece of text. In particular it uses a LLM to embed a piece of text. 
The strategy config is as follows: + + ```yaml + strategy: + type: openai + llm: # The configuration for the LLM + type: openai_embedding # the type of llm to use, available options are: openai_embedding, azure_openai_embedding + api_key: !ENV ${GRAPHRAG_OPENAI_API_KEY} # The api key to use for openai + model: !ENV ${GRAPHRAG_OPENAI_MODEL:gpt-4-turbo-preview} # The model to use for openai + max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai + organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai + vector_store: # The optional configuration for the vector store + type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb, kusto + <...> + ``` + """ + vector_store_config = strategy.get("vector_store") + + if vector_store_config and not vector_store_config.get("index_in_memory"): + embedding_name = kwargs.get("embedding_name", "default") + vector_name = kwargs.get("vector_name", "vector") + collection_name = _get_collection_name(vector_store_config, embedding_name) + vector_name = _get_collection_name(vector_store_config, vector_name) + vector_store: BaseVectorStore = _create_vector_store( + vector_store_config, collection_name, vector_name, "reports" + ) + vector_store_workflow_config = vector_store_config.get( + embedding_name, vector_store_config + ) + return await _text_embed_with_vector_store( + input, + callbacks, + cache, + column, + strategy, + vector_store, + vector_store_workflow_config, + vector_store_config.get("store_in_table", False), + kwargs.get("to", f"{column}_embedding"), + ) + + return await _text_embed_in_memory( + input, + callbacks, + cache, + column, + strategy, + kwargs.get("to", f"{column}_embedding"), + ) + + +async def _text_embed_in_memory( + input: VerbInput, + callbacks: VerbCallbacks, + cache: PipelineCache, + column: str, + strategy: dict, + to: str, +): + output_df = cast(pd.DataFrame, input.get_input()) + strategy_type = strategy["type"] + strategy_exec = load_strategy(strategy_type) + strategy_args = {**strategy} + input_table = input.get_input() + + texts: list[str] = input_table[column].to_numpy().tolist() + result = await strategy_exec(texts, callbacks, cache, strategy_args) + + output_df[to] = result.embeddings + return TableContainer(table=output_df) + + +async def _text_embed_with_vector_store( + input: VerbInput, + callbacks: VerbCallbacks, + cache: PipelineCache, + column: str, + strategy: dict[str, Any], + vector_store: BaseVectorStore, + vector_store_config: dict, + store_in_table: bool = False, + to: str = "", +): + output_df = cast(pd.DataFrame, input.get_input()) + strategy_type = strategy["type"] + strategy_exec = load_strategy(strategy_type) + strategy_args = {**strategy} + + # Get vector-storage configuration + insert_batch_size: int = ( + vector_store_config.get("batch_size") or DEFAULT_EMBEDDING_BATCH_SIZE + ) + title_column: str = vector_store_config.get("title_column", "title") + id_column: str = vector_store_config.get("id_column", "id") + overwrite: bool = vector_store_config.get("overwrite", True) + + if column not in output_df.columns: + msg = f"Column {column} not found in input dataframe with columns {output_df.columns}" + raise ValueError(msg) + if title_column not in output_df.columns: + msg = f"Column {title_column} not found in input dataframe with columns {output_df.columns}" + raise ValueError(msg) + if id_column not in output_df.columns: + msg = f"Column {id_column} not found in input dataframe with columns 
{output_df.columns}" + raise ValueError(msg) + + total_rows = 0 + for row in output_df[column]: + if isinstance(row, list): + total_rows += len(row) + else: + total_rows += 1 + + i = 0 + starting_index = 0 + + all_results = [] + + while insert_batch_size * i < input.get_input().shape[0]: + batch = input.get_input().iloc[ + insert_batch_size * i : insert_batch_size * (i + 1) + ] + texts: list[str] = batch[column].to_numpy().tolist() + titles: list[str] = batch[title_column].to_numpy().tolist() + ids: list[str] = batch[id_column].to_numpy().tolist() + result = await strategy_exec( + texts, + callbacks, + cache, + strategy_args, + ) + if store_in_table and result.embeddings: + embeddings = [ + embedding for embedding in result.embeddings if embedding is not None + ] + all_results.extend(embeddings) + + vectors = result.embeddings or [] + documents: list[VectorStoreDocument] = [] + for id, text, title, vector in zip(ids, texts, titles, vectors, strict=True): + if type(vector) is np.ndarray: + vector = vector.tolist() + document = VectorStoreDocument( + id=id, + text=text, + vector=vector, + attributes={"title": title}, + ) + documents.append(document) + + vector_store.load_documents(documents, overwrite and i == 0) + starting_index += len(documents) + i += 1 + + if store_in_table: + output_df[to] = all_results + + return TableContainer(table=output_df) + + +def _create_vector_store( + vector_store_config: dict, collection_name: str, vector_name: str, reports_name: str, +) -> BaseVectorStore: + vector_store_type: str = str(vector_store_config.get("type")) + if collection_name: + vector_store_config.update({"collection_name": collection_name}) + if vector_name: + vector_store_config.update({"vector_name": vector_name}) + if reports_name: + vector_store_config.update({"reports_name": reports_name}) + + vector_store = VectorStoreFactory.get_vector_store( + vector_store_type, kwargs=vector_store_config + ) + + vector_store.connect(**vector_store_config) + return vector_store + + +def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str: + collection_name = vector_store_config.get("collection_name") + if not collection_name: + collection_names = vector_store_config.get("collection_names", {}) + collection_name = collection_names.get(embedding_name, embedding_name) + + msg = f"using {vector_store_config.get('type')} collection_name {collection_name} for embedding {embedding_name}" + log.info(msg) + return collection_name + + +def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy: + """Load strategy method definition.""" + match strategy: + case TextEmbedStrategyType.openai: + from .strategies.openai import run as run_openai + + return run_openai + case TextEmbedStrategyType.mock: + from .strategies.mock import run as run_mock + + return run_mock + case _: + msg = f"Unknown strategy: {strategy}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/text/replace/__init__.py b/func-app/graphrag/index/verbs/text/replace/__init__.py new file mode 100644 index 0000000000..f863415f40 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/replace/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""The Indexing Engine text replace package root."""
+
+from .replace import text_replace
+
+__all__ = ["text_replace"]
diff --git a/func-app/graphrag/index/verbs/text/replace/replace.py b/func-app/graphrag/index/verbs/text/replace/replace.py
new file mode 100644
index 0000000000..386fac3459
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/replace/replace.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing replace and _apply_replacements methods."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+from .typing import Replacement
+
+
+@verb(name="text_replace")
+def text_replace(
+    input: VerbInput,
+    column: str,
+    to: str,
+    replacements: list[dict[str, str]],
+    **_kwargs: dict,
+) -> TableContainer:
+    """
+    Apply a set of replacements to a piece of text.
+
+    ## Usage
+    ```yaml
+    verb: text_replace
+    args:
+        column: # The name of the column containing the text to replace
+        to: # The name of the column to write the replaced text to
+        replacements: # A list of replacements to apply
+            - pattern: # The literal text to find (plain string matching, not a regex)
+              replacement: # The string to replace with
+    ```
+    """
+    output = cast(pd.DataFrame, input.get_input())
+    parsed_replacements = [Replacement(**r) for r in replacements]
+    output[to] = output[column].apply(
+        lambda text: _apply_replacements(text, parsed_replacements)
+    )
+    return TableContainer(table=output)
+
+
+def _apply_replacements(text: str, replacements: list[Replacement]) -> str:
+    for r in replacements:
+        text = text.replace(r.pattern, r.replacement)
+    return text
diff --git a/func-app/graphrag/index/verbs/text/replace/typing.py b/func-app/graphrag/index/verbs/text/replace/typing.py
new file mode 100644
index 0000000000..45beef9f28
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/replace/typing.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing 'Replacement' model."""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class Replacement:
+    """Replacement class definition."""
+
+    pattern: str
+    replacement: str
diff --git a/func-app/graphrag/index/verbs/text/split.py b/func-app/graphrag/index/verbs/text/split.py
new file mode 100644
index 0000000000..b1339ff455
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/split.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the text_split method definition."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+
+@verb(name="text_split")
+def text_split(
+    input: VerbInput,
+    column: str,
+    to: str,
+    separator: str = ",",
+    **_kwargs: dict,
+) -> TableContainer:
+    """
+    Split a piece of text into a list of strings based on a delimiter. The verb outputs a new column containing a list of strings.
+ + ## Usage + + ```yaml + verb: text_split + args: + column: text # The name of the column containing the text to split + to: split_text # The name of the column to output the split text to + separator: "," # The separator to split the text on, defaults to "," + ``` + """ + output = text_split_df(cast(pd.DataFrame, input.get_input()), column, to, separator) + return TableContainer(table=output) + + +def text_split_df( + input: pd.DataFrame, column: str, to: str, separator: str = "," +) -> pd.DataFrame: + """Split a column into a list of strings.""" + output = input + + def _apply_split(row): + if row[column] is None or isinstance(row[column], list): + return row[column] + if row[column] == "": + return [] + if not isinstance(row[column], str): + message = f"Expected {column} to be a string, but got {type(row[column])}" + raise TypeError(message) + return row[column].split(separator) + + output[to] = output.apply(_apply_split, axis=1) + return output diff --git a/func-app/graphrag/index/verbs/text/translate/__init__.py b/func-app/graphrag/index/verbs/text/translate/__init__.py new file mode 100644 index 0000000000..ad830dfa87 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/translate/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine text translate package root.""" + +from .text_translate import text_translate + +__all__ = ["text_translate"] diff --git a/func-app/graphrag/index/verbs/text/translate/strategies/__init__.py b/func-app/graphrag/index/verbs/text/translate/strategies/__init__.py new file mode 100644 index 0000000000..d418bbae28 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/translate/strategies/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine translate strategies package root.""" + +from .mock import run as run_mock +from .openai import run as run_openai + +__all__ = ["run_mock", "run_openai"] diff --git a/func-app/graphrag/index/verbs/text/translate/strategies/defaults.py b/func-app/graphrag/index/verbs/text/translate/strategies/defaults.py new file mode 100644 index 0000000000..003e00eb1f --- /dev/null +++ b/func-app/graphrag/index/verbs/text/translate/strategies/defaults.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A file containing TRANSLATION_PROMPT value definition.""" + +TRANSLATION_PROMPT = """ + You are a helpful assistant. Translate into {language} the following text, and make sure all of the text is in {language}. + """.strip() diff --git a/func-app/graphrag/index/verbs/text/translate/strategies/mock.py b/func-app/graphrag/index/verbs/text/translate/strategies/mock.py new file mode 100644 index 0000000000..58a5a9995e --- /dev/null +++ b/func-app/graphrag/index/verbs/text/translate/strategies/mock.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A module containing run and _translate_text methods definitions."""
+
+from typing import Any
+
+from datashaper import VerbCallbacks
+
+from graphrag.index.cache import PipelineCache
+
+from .typing import TextTranslationResult
+
+
+async def run(  # noqa RUF029 async is required for interface
+    input: str | list[str],
+    _args: dict[str, Any],
+    _reporter: VerbCallbacks,
+    _cache: PipelineCache,
+) -> TextTranslationResult:
+    """Run the mock text-translation strategy."""
+    input = [input] if isinstance(input, str) else input
+    return TextTranslationResult(translations=[_translate_text(text) for text in input])
+
+
+def _translate_text(text: str) -> str:
+    """Translate a single piece of text."""
+    return f"{text} translated"
diff --git a/func-app/graphrag/index/verbs/text/translate/strategies/openai.py b/func-app/graphrag/index/verbs/text/translate/strategies/openai.py
new file mode 100644
index 0000000000..49c47b34a2
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/translate/strategies/openai.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run and _translate_text methods definition."""
+
+import logging
+import traceback
+from typing import Any
+
+from datashaper import VerbCallbacks
+
+import graphrag.config.defaults as defs
+from graphrag.config.enums import LLMType
+from graphrag.index.cache import PipelineCache
+from graphrag.index.llm import load_llm
+from graphrag.index.text_splitting import TokenTextSplitter
+from graphrag.llm import CompletionLLM
+
+from .defaults import TRANSLATION_PROMPT as DEFAULT_TRANSLATION_PROMPT
+from .typing import TextTranslationResult
+
+log = logging.getLogger(__name__)
+
+
+async def run(
+    input: str | list[str],
+    args: dict[str, Any],
+    callbacks: VerbCallbacks,
+    pipeline_cache: PipelineCache,
+) -> TextTranslationResult:
+    """Run the LLM text-translation strategy."""
+    llm_config = args.get("llm", {"type": LLMType.StaticResponse})
+    llm_type = llm_config.get("type", LLMType.StaticResponse)
+    llm = load_llm(
+        "text_translation",
+        llm_type,
+        callbacks,
+        pipeline_cache,
+        llm_config,
+        chat_only=True,
+    )
+    language = args.get("language", "English")
+    prompt = args.get("prompt")
+    chunk_size = args.get("chunk_size", defs.CHUNK_SIZE)
+    chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP)
+
+    input = [input] if isinstance(input, str) else input
+    return TextTranslationResult(
+        translations=[
+            await _translate_text(
+                text, language, prompt, llm, chunk_size, chunk_overlap, callbacks
+            )
+            for text in input
+        ]
+    )
+
+
+async def _translate_text(
+    text: str,
+    language: str,
+    prompt: str | None,
+    llm: CompletionLLM,
+    chunk_size: int,
+    chunk_overlap: int,
+    callbacks: VerbCallbacks,
+) -> str:
+    """Translate a single piece of text."""
+    splitter = TokenTextSplitter(
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+    )
+
+    out = ""
+    chunks = splitter.split_text(text)
+    for chunk in chunks:
+        try:
+            result = await llm(
+                chunk,
+                history=[
+                    {
+                        "role": "system",
+                        "content": (prompt or DEFAULT_TRANSLATION_PROMPT),
+                    }
+                ],
+                variables={"language": language},
+            )
+            out += result.output or ""
+        except Exception as e:
+            log.exception("error translating text")
+            callbacks.error("Error translating text", e, traceback.format_exc())
+            out += ""
+
+    return out
diff --git a/func-app/graphrag/index/verbs/text/translate/strategies/typing.py b/func-app/graphrag/index/verbs/text/translate/strategies/typing.py
new file mode 100644 index 0000000000..d91ed735f5 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/translate/strategies/typing.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'TextTranslationResult' model.""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any + +from datashaper import VerbCallbacks + +from graphrag.index.cache import PipelineCache + + +@dataclass +class TextTranslationResult: + """Text translation result class definition.""" + + translations: list[str] + + +TextTranslationStrategy = Callable[ + [list[str], dict[str, Any], VerbCallbacks, PipelineCache], + Awaitable[TextTranslationResult], +] diff --git a/func-app/graphrag/index/verbs/text/translate/text_translate.py b/func-app/graphrag/index/verbs/text/translate/text_translate.py new file mode 100644 index 0000000000..8d0faffefa --- /dev/null +++ b/func-app/graphrag/index/verbs/text/translate/text_translate.py @@ -0,0 +1,120 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing text_translate methods definition.""" + +from enum import Enum +from typing import Any, cast + +import pandas as pd +from datashaper import ( + AsyncType, + TableContainer, + VerbCallbacks, + VerbInput, + derive_from_rows, + verb, +) + +from graphrag.index.cache import PipelineCache + +from .strategies.typing import TextTranslationStrategy + + +class TextTranslateStrategyType(str, Enum): + """TextTranslateStrategyType class definition.""" + + openai = "openai" + mock = "mock" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +@verb(name="text_translate") +async def text_translate( + input: VerbInput, + cache: PipelineCache, + callbacks: VerbCallbacks, + text_column: str, + to: str, + strategy: dict[str, Any], + async_mode: AsyncType = AsyncType.AsyncIO, + **kwargs, +) -> TableContainer: + """ + Translate a piece of text into another language. + + ## Usage + ```yaml + verb: text_translate + args: + text_column: # The name of the column containing the text to translate + to: # The name of the column to write the translated text to + strategy: # The strategy to use to translate the text, see below for more details + ``` + + ## Strategies + The text translate verb uses a strategy to translate the text. The strategy is an object which defines the strategy to use. The following strategies are available: + + ### openai + This strategy uses openai to translate a piece of text. In particular it uses a LLM to translate a piece of text. 
The strategy config is as follows: + + ```yaml + strategy: + type: openai + language: english # The language to translate to, default: english + prompt: # The prompt to use for the translation, default: None + chunk_size: 2500 # The chunk size to use for the translation, default: 2500 + chunk_overlap: 0 # The chunk overlap to use for the translation, default: 0 + llm: # The configuration for the LLM + type: openai_chat # the type of llm to use, available options are: openai_chat, azure_openai_chat + api_key: !ENV ${GRAPHRAG_OPENAI_API_KEY} # The api key to use for openai + model: !ENV ${GRAPHRAG_OPENAI_MODEL:gpt-4-turbo-preview} # The model to use for openai + max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai + organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai + ``` + """ + output_df = cast(pd.DataFrame, input.get_input()) + strategy_type = strategy["type"] + strategy_args = {**strategy} + strategy_exec = _load_strategy(strategy_type) + + async def run_strategy(row): + text = row[text_column] + result = await strategy_exec(text, strategy_args, callbacks, cache) + + # If it is a single string, then return just the translation for that string + if isinstance(text, str): + return result.translations[0] + + # Otherwise, return a list of translations, one for each item in the input + return list(result.translations) + + results = await derive_from_rows( + output_df, + run_strategy, + callbacks, + scheduling_type=async_mode, + num_threads=kwargs.get("num_threads", 4), + ) + output_df[to] = results + return TableContainer(table=output_df) + + +def _load_strategy(strategy: TextTranslateStrategyType) -> TextTranslationStrategy: + match strategy: + case TextTranslateStrategyType.openai: + from .strategies.openai import run as run_openai + + return run_openai + + case TextTranslateStrategyType.mock: + from .strategies.mock import run as run_mock + + return run_mock + + case _: + msg = f"Unknown strategy: {strategy}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/unzip.py b/func-app/graphrag/index/verbs/unzip.py new file mode 100644 index 0000000000..4d8c8da08e --- /dev/null +++ b/func-app/graphrag/index/verbs/unzip.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing unzip method definition.""" + +from typing import cast + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + + +# TODO: Check if this is already a thing +# Takes 1|(x,y)|b +# and converts to +# 1|x|y|b +@verb(name="unzip") +def unzip( + input: VerbInput, column: str, to: list[str], **_kwargs: dict +) -> TableContainer: + """Unpacks a column containing a tuple into multiple columns.""" + table = cast(pd.DataFrame, input.get_input()) + + table[to] = pd.DataFrame(table[column].tolist(), index=table.index) + + return TableContainer(table=table) diff --git a/func-app/graphrag/index/verbs/zip.py b/func-app/graphrag/index/verbs/zip.py new file mode 100644 index 0000000000..462395d3da --- /dev/null +++ b/func-app/graphrag/index/verbs/zip.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024 Microsoft Corporation. 
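For reference, the openai translation strategy above boils down to a split/translate/concatenate loop. Here is a minimal self-contained sketch of that pattern; the splitter and LLM are stubs invented for the demo, standing in for `TokenTextSplitter` and the loaded chat model, which need a full pipeline config to construct:

```python
import asyncio


def split_text(text: str, chunk_size: int = 20) -> list[str]:
    """Naive whitespace splitter capping chunks at chunk_size words (stub)."""
    words = text.split()
    return [
        " ".join(words[i : i + chunk_size])
        for i in range(0, len(words), chunk_size)
    ]


async def fake_llm(chunk: str, system_prompt: str) -> str:
    """Pretend translation; the real strategy awaits the chat model here."""
    return f"<translated:{chunk}>"


async def translate_text(text: str, language: str = "English") -> str:
    """Mirror of _translate_text: translate chunk-by-chunk, concatenate."""
    prompt = f"Translate the user text into {language}."
    out = ""
    for chunk in split_text(text):
        try:
            out += await fake_llm(chunk, prompt)
        except Exception:
            # The real strategy logs the error and keeps going with "".
            continue
    return out


print(asyncio.run(translate_text("one two three " * 20)))
```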
+# Licensed under the MIT License + +"""A module containing ds_zip method definition.""" + +from typing import cast + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + + +@verb(name="zip") +def zip_verb( + input: VerbInput, + to: str, + columns: list[str], + type: str | None = None, # noqa A002 + **_kwargs: dict, +) -> TableContainer: + """ + Zip columns together. + + ## Usage + TODO + + """ + table = cast(pd.DataFrame, input.get_input()) + if type is None: + table[to] = list(zip(*[table[col] for col in columns], strict=True)) + + # This one is a little weird + elif type == "dict": + if len(columns) != 2: + msg = f"Expected exactly two columns for a dict, got {columns}" + raise ValueError(msg) + key_col, value_col = columns + + results = [] + for _, row in table.iterrows(): + keys = row[key_col] + values = row[value_col] + output = {} + if len(keys) != len(values): + msg = f"Expected same number of keys and values, got {len(keys)} keys and {len(values)} values" + raise ValueError(msg) + for idx, key in enumerate(keys): + output[key] = values[idx] + results.append(output) + + table[to] = results + return TableContainer(table=table.reset_index(drop=True)) diff --git a/func-app/graphrag/index/workflows/__init__.py b/func-app/graphrag/index/workflows/__init__.py new file mode 100644 index 0000000000..ed580309a8 --- /dev/null +++ b/func-app/graphrag/index/workflows/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine workflows package root.""" + +from .load import create_workflow, load_workflows +from .typing import ( + StepDefinition, + VerbDefinitions, + VerbTiming, + WorkflowConfig, + WorkflowDefinitions, + WorkflowToRun, +) + +__all__ = [ + "StepDefinition", + "VerbDefinitions", + "VerbTiming", + "WorkflowConfig", + "WorkflowDefinitions", + "WorkflowToRun", + "create_workflow", + "load_workflows", +] diff --git a/func-app/graphrag/index/workflows/default_workflows.py b/func-app/graphrag/index/workflows/default_workflows.py new file mode 100644 index 0000000000..81112bee32 --- /dev/null +++ b/func-app/graphrag/index/workflows/default_workflows.py @@ -0,0 +1,121 @@ +# Copyright (c) 2024 Microsoft Corporation. 
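Taken together, `zip` and `unzip` are inverse column transforms. Stripped of the verb/`TableContainer` plumbing, the core pandas moves are roughly the following (table contents are made up for the demo):

```python
import pandas as pd

df = pd.DataFrame({"id": [1, 2], "text": ["alpha", "beta"], "b": [10, 20]})

# zip: pack (id, text) into a single tuple column, as zip_verb does
df["text_with_ids"] = list(zip(df["id"], df["text"], strict=True))

# unzip: spread the tuple column back out into named columns
df[["id2", "text2"]] = pd.DataFrame(df["text_with_ids"].tolist(), index=df.index)

print(df)
# The round trip preserves the originals: id == id2 and text == text2
assert df["id"].equals(df["id2"]) and df["text"].equals(df["text2"])
```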
+# Licensed under the MIT License + +"""A package containing default workflows definitions.""" + +from .typing import WorkflowDefinitions +from .v1.create_base_documents import ( + build_steps as build_create_base_documents_steps, +) +from .v1.create_base_documents import ( + workflow_name as create_base_documents, +) +from .v1.create_base_entity_graph import ( + build_steps as build_create_base_entity_graph_steps, +) +from .v1.create_base_entity_graph import ( + workflow_name as create_base_entity_graph, +) +from .v1.create_base_extracted_entities import ( + build_steps as build_create_base_extracted_entities_steps, +) +from .v1.create_base_extracted_entities import ( + workflow_name as create_base_extracted_entities, +) +from .v1.create_base_text_units import ( + build_steps as build_create_base_text_units_steps, +) +from .v1.create_base_text_units import ( + workflow_name as create_base_text_units, +) +from .v1.create_final_communities import ( + build_steps as build_create_final_communities_steps, +) +from .v1.create_final_communities import ( + workflow_name as create_final_communities, +) +from .v1.create_final_community_reports import ( + build_steps as build_create_final_community_reports_steps, +) +from .v1.create_final_community_reports import ( + workflow_name as create_final_community_reports, +) +from .v1.create_final_covariates import ( + build_steps as build_create_final_covariates_steps, +) +from .v1.create_final_covariates import ( + workflow_name as create_final_covariates, +) +from .v1.create_final_documents import ( + build_steps as build_create_final_documents_steps, +) +from .v1.create_final_documents import ( + workflow_name as create_final_documents, +) +from .v1.create_final_entities import ( + build_steps as build_create_final_entities_steps, +) +from .v1.create_final_entities import ( + workflow_name as create_final_entities, +) +from .v1.create_final_nodes import ( + build_steps as build_create_final_nodes_steps, +) +from .v1.create_final_nodes import ( + workflow_name as create_final_nodes, +) +from .v1.create_final_relationships import ( + build_steps as build_create_final_relationships_steps, +) +from .v1.create_final_relationships import ( + workflow_name as create_final_relationships, +) +from .v1.create_final_text_units import ( + build_steps as build_create_final_text_units, +) +from .v1.create_final_text_units import ( + workflow_name as create_final_text_units, +) +from .v1.create_summarized_entities import ( + build_steps as build_create_summarized_entities_steps, +) +from .v1.create_summarized_entities import ( + workflow_name as create_summarized_entities, +) +from .v1.join_text_units_to_covariate_ids import ( + build_steps as join_text_units_to_covariate_ids_steps, +) +from .v1.join_text_units_to_covariate_ids import ( + workflow_name as join_text_units_to_covariate_ids, +) +from .v1.join_text_units_to_entity_ids import ( + build_steps as join_text_units_to_entity_ids_steps, +) +from .v1.join_text_units_to_entity_ids import ( + workflow_name as join_text_units_to_entity_ids, +) +from .v1.join_text_units_to_relationship_ids import ( + build_steps as join_text_units_to_relationship_ids_steps, +) +from .v1.join_text_units_to_relationship_ids import ( + workflow_name as join_text_units_to_relationship_ids, +) + +default_workflows: WorkflowDefinitions = { + create_base_extracted_entities: build_create_base_extracted_entities_steps, + create_base_entity_graph: build_create_base_entity_graph_steps, + create_base_text_units: 
build_create_base_text_units_steps, + create_final_text_units: build_create_final_text_units, + create_final_community_reports: build_create_final_community_reports_steps, + create_final_nodes: build_create_final_nodes_steps, + create_final_relationships: build_create_final_relationships_steps, + create_final_documents: build_create_final_documents_steps, + create_final_covariates: build_create_final_covariates_steps, + create_base_documents: build_create_base_documents_steps, + create_final_entities: build_create_final_entities_steps, + create_final_communities: build_create_final_communities_steps, + create_summarized_entities: build_create_summarized_entities_steps, + join_text_units_to_entity_ids: join_text_units_to_entity_ids_steps, + join_text_units_to_covariate_ids: join_text_units_to_covariate_ids_steps, + join_text_units_to_relationship_ids: join_text_units_to_relationship_ids_steps, +} diff --git a/func-app/graphrag/index/workflows/load.py b/func-app/graphrag/index/workflows/load.py new file mode 100644 index 0000000000..4dd6f9bfd0 --- /dev/null +++ b/func-app/graphrag/index/workflows/load.py @@ -0,0 +1,171 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing load_workflows, create_workflow, _get_steps_for_workflow and _remove_disabled_steps methods definition.""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, NamedTuple, cast + +from datashaper import Workflow + +from graphrag.index.errors import ( + NoWorkflowsDefinedError, + UndefinedWorkflowError, + UnknownWorkflowError, +) +from graphrag.index.utils import topological_sort + +from .default_workflows import default_workflows as _default_workflows +from .typing import VerbDefinitions, WorkflowDefinitions, WorkflowToRun + +if TYPE_CHECKING: + from graphrag.index.config import ( + PipelineWorkflowConfig, + PipelineWorkflowReference, + PipelineWorkflowStep, + ) + +anonymous_workflow_count = 0 + +VerbFn = Callable[..., Any] +log = logging.getLogger(__name__) + + +class LoadWorkflowResult(NamedTuple): + """A workflow loading result object.""" + + workflows: list[WorkflowToRun] + """The loaded workflow names in the order they should be run.""" + + dependencies: dict[str, list[str]] + """A dictionary of workflow name to workflow dependencies.""" + + +def load_workflows( + workflows_to_load: list[PipelineWorkflowReference], + additional_verbs: VerbDefinitions | None = None, + additional_workflows: WorkflowDefinitions | None = None, + memory_profile: bool = False, +) -> LoadWorkflowResult: + """Load the given workflows. 
+ + Args: + - workflows_to_load - The workflows to load + - additional_verbs - The list of custom verbs available to the workflows + - additional_workflows - The list of custom workflows + Returns: + - output[0] - The loaded workflow names in the order they should be run + - output[1] - A dictionary of workflow name to workflow dependencies + """ + workflow_graph: dict[str, WorkflowToRun] = {} + + global anonymous_workflow_count + for reference in workflows_to_load: + name = reference.name + is_anonymous = name is None or name.strip() == "" + if is_anonymous: + name = f"Anonymous Workflow {anonymous_workflow_count}" + anonymous_workflow_count += 1 + name = cast(str, name) + + config = reference.config + workflow = create_workflow( + name or "MISSING NAME!", + reference.steps, + config, + additional_verbs, + additional_workflows, + ) + workflow_graph[name] = WorkflowToRun(workflow, config=config or {}) + + # Backfill any missing workflows + for name in list(workflow_graph.keys()): + workflow = workflow_graph[name] + deps = [ + d.replace("workflow:", "") + for d in workflow.workflow.dependencies + if d.startswith("workflow:") + ] + for dependency in deps: + if dependency not in workflow_graph: + reference = {"name": dependency, **workflow.config} + workflow_graph[dependency] = WorkflowToRun( + workflow=create_workflow( + dependency, + config=reference, + additional_verbs=additional_verbs, + additional_workflows=additional_workflows, + memory_profile=memory_profile, + ), + config=reference, + ) + + # Run workflows in order of dependencies + def filter_wf_dependencies(name: str) -> list[str]: + externals = [ + e.replace("workflow:", "") + for e in workflow_graph[name].workflow.dependencies + ] + return [e for e in externals if e in workflow_graph] + + task_graph = {name: filter_wf_dependencies(name) for name in workflow_graph} + workflow_run_order = topological_sort(task_graph) + workflows = [workflow_graph[name] for name in workflow_run_order] + log.info("Workflow Run Order: %s", workflow_run_order) + return LoadWorkflowResult(workflows=workflows, dependencies=task_graph) + + +def create_workflow( + name: str, + steps: list[PipelineWorkflowStep] | None = None, + config: PipelineWorkflowConfig | None = None, + additional_verbs: VerbDefinitions | None = None, + additional_workflows: WorkflowDefinitions | None = None, + memory_profile: bool = False, +) -> Workflow: + """Create a workflow from the given config.""" + additional_workflows = { + **_default_workflows, + **(additional_workflows or {}), + } + steps = steps or _get_steps_for_workflow(name, config, additional_workflows) + steps = _remove_disabled_steps(steps) + return Workflow( + verbs=additional_verbs or {}, + schema={ + "name": name, + "steps": steps, + }, + validate=False, + memory_profile=memory_profile, + ) + + +def _get_steps_for_workflow( + name: str | None, + config: PipelineWorkflowConfig | None, + workflows: dict[str, Callable] | None, +) -> list[PipelineWorkflowStep]: + """Get the steps for the given workflow config.""" + if config is not None and "steps" in config: + return config["steps"] + + if workflows is None: + raise NoWorkflowsDefinedError + + if name is None: + raise UndefinedWorkflowError + + if name not in workflows: + raise UnknownWorkflowError(name) + + return workflows[name](config or {}) + + +def _remove_disabled_steps( + steps: list[PipelineWorkflowStep], +) -> list[PipelineWorkflowStep]: + return [step for step in steps if step.get("enabled", True)] diff --git a/func-app/graphrag/index/workflows/typing.py 
b/func-app/graphrag/index/workflows/typing.py new file mode 100644 index 0000000000..3b44545bd4 --- /dev/null +++ b/func-app/graphrag/index/workflows/typing.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'WorkflowToRun' model.""" + +from collections.abc import Callable +from dataclasses import dataclass as dc_dataclass +from typing import Any + +from datashaper import TableContainer, Workflow + +StepDefinition = dict[str, Any] +"""A step definition.""" + +VerbDefinitions = dict[str, Callable[..., TableContainer]] +"""A mapping of verb names to their implementations.""" + +WorkflowConfig = dict[str, Any] +"""A workflow configuration.""" + +WorkflowDefinitions = dict[str, Callable[[WorkflowConfig], list[StepDefinition]]] +"""A mapping of workflow names to their implementations.""" + +VerbTiming = dict[str, float] +"""The timings of verbs by id.""" + + +@dc_dataclass +class WorkflowToRun: + """Workflow to run class definition.""" + + workflow: Workflow + config: dict[str, Any] diff --git a/func-app/graphrag/index/workflows/v1/__init__.py b/func-app/graphrag/index/workflows/v1/__init__.py new file mode 100644 index 0000000000..69518f5ee2 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine workflows package root.""" diff --git a/func-app/graphrag/index/workflows/v1/create_base_documents.py b/func-app/graphrag/index/workflows/v1/create_base_documents.py new file mode 100644 index 0000000000..bd7094c64a --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_base_documents.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from datashaper import DEFAULT_INPUT_NAME + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_base_documents" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the documents table. 
+ + ## Dependencies + * `workflow:create_final_text_units` + """ + document_attribute_columns = config.get("document_attribute_columns", []) + return [ + { + "verb": "unroll", + "args": {"column": "document_ids"}, + "input": {"source": "workflow:create_final_text_units"}, + }, + { + "verb": "select", + "args": { + # We only need the chunk id and the document id + "columns": ["id", "document_ids", "text"] + }, + }, + { + "id": "rename_chunk_doc_id", + "verb": "rename", + "args": { + "columns": { + "document_ids": "chunk_doc_id", + "id": "chunk_id", + "text": "chunk_text", + } + }, + }, + { + "verb": "join", + "args": { + # Join the doc id from the chunk onto the original document + "on": ["chunk_doc_id", "id"] + }, + "input": {"source": "rename_chunk_doc_id", "others": [DEFAULT_INPUT_NAME]}, + }, + { + "id": "docs_with_text_units", + "verb": "aggregate_override", + "args": { + "groupby": ["id"], + "aggregations": [ + { + "column": "chunk_id", + "operation": "array_agg", + "to": "text_units", + } + ], + }, + }, + { + "verb": "join", + "args": { + "on": ["id", "id"], + "strategy": "right outer", + }, + "input": { + "source": "docs_with_text_units", + "others": [DEFAULT_INPUT_NAME], + }, + }, + { + "verb": "rename", + "args": {"columns": {"text": "raw_content"}}, + }, + *[ + { + "verb": "convert", + "args": { + "column": column, + "to": column, + "type": "string", + }, + } + for column in document_attribute_columns + ], + { + "verb": "merge_override", + "enabled": len(document_attribute_columns) > 0, + "args": { + "columns": document_attribute_columns, + "strategy": "json", + "to": "attributes", + }, + }, + {"verb": "convert", "args": {"column": "id", "to": "id", "type": "string"}}, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_base_entity_graph.py b/func-app/graphrag/index/workflows/v1/create_base_entity_graph.py new file mode 100644 index 0000000000..b001aad218 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_base_entity_graph.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_base_entity_graph" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the base table for the entity graph. 
+ + ## Dependencies + * `workflow:create_base_extracted_entities` + """ + clustering_config = config.get( + "cluster_graph", + {"strategy": {"type": "leiden"}}, + ) + embed_graph_config = config.get( + "embed_graph", + { + "strategy": { + "type": "node2vec", + "num_walks": config.get("embed_num_walks", 10), + "walk_length": config.get("embed_walk_length", 40), + "window_size": config.get("embed_window_size", 2), + "iterations": config.get("embed_iterations", 3), + "random_seed": config.get("embed_random_seed", 86), + } + }, + ) + + graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False + embed_graph_enabled = config.get("embed_graph_enabled", False) or False + + return [ + { + "verb": "cluster_graph", + "args": { + **clustering_config, + "column": "entity_graph", + "to": "clustered_graph", + "level_to": "level", + }, + "input": ({"source": "workflow:create_summarized_entities"}), + }, + { + "verb": "snapshot_rows", + "enabled": graphml_snapshot_enabled, + "args": { + "base_name": "clustered_graph", + "column": "clustered_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + }, + { + "verb": "embed_graph", + "enabled": embed_graph_enabled, + "args": { + "column": "clustered_graph", + "to": "embeddings", + **embed_graph_config, + }, + }, + { + "verb": "snapshot_rows", + "enabled": graphml_snapshot_enabled, + "args": { + "base_name": "embedded_graph", + "column": "entity_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + }, + { + "verb": "select", + "args": { + # only selecting for documentation sake, so we know what is contained in + # this workflow + "columns": ( + ["level", "clustered_graph", "embeddings"] + if embed_graph_enabled + else ["level", "clustered_graph"] + ), + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_base_extracted_entities.py b/func-app/graphrag/index/workflows/v1/create_base_extracted_entities.py new file mode 100644 index 0000000000..30d608e9fd --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_base_extracted_entities.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from datashaper import AsyncType + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_base_extracted_entities" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the base table for extracted entities. 
+ + ## Dependencies + * `workflow:create_base_text_units` + """ + entity_extraction_config = config.get("entity_extract", {}) + graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False + raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False + + return [ + { + "verb": "entity_extract", + "args": { + **entity_extraction_config, + "column": entity_extraction_config.get("text_column", "chunk"), + "id_column": entity_extraction_config.get("id_column", "chunk_id"), + "async_mode": entity_extraction_config.get( + "async_mode", AsyncType.AsyncIO + ), + "to": "entities", + "graph_to": "entity_graph", + }, + "input": {"source": "workflow:create_base_text_units"}, + }, + { + "verb": "snapshot", + "enabled": raw_entity_snapshot_enabled, + "args": { + "name": "raw_extracted_entities", + "formats": ["json"], + }, + }, + { + "verb": "merge_graphs", + "args": { + "column": "entity_graph", + "to": "entity_graph", + **config.get( + "graph_merge_operations", + { + "nodes": { + "source_id": { + "operation": "concat", + "delimiter": ", ", + "distinct": True, + }, + "description": ({ + "operation": "concat", + "separator": "\n", + "distinct": False, + }), + }, + "edges": { + "source_id": { + "operation": "concat", + "delimiter": ", ", + "distinct": True, + }, + "description": ({ + "operation": "concat", + "separator": "\n", + "distinct": False, + }), + "weight": "sum", + }, + }, + ), + }, + }, + { + "verb": "snapshot_rows", + "enabled": graphml_snapshot_enabled, + "args": { + "base_name": "merged_graph", + "column": "entity_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_base_text_units.py b/func-app/graphrag/index/workflows/v1/create_base_text_units.py new file mode 100644 index 0000000000..63876e5e49 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_base_text_units.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from datashaper import DEFAULT_INPUT_NAME + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_base_text_units" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the base table for text units. 
+ + ## Dependencies + None + """ + chunk_column_name = config.get("chunk_column", "chunk") + chunk_by_columns = config.get("chunk_by", []) or [] + n_tokens_column_name = config.get("n_tokens_column", "n_tokens") + return [ + { + "verb": "orderby", + "args": { + "orders": [ + # sort for reproducibility + {"column": "id", "direction": "asc"}, + ] + }, + "input": {"source": DEFAULT_INPUT_NAME}, + }, + { + "verb": "zip", + "args": { + # Pack the document ids with the text + # So when we unpack the chunks, we can restore the document id + "columns": ["id", "text"], + "to": "text_with_ids", + }, + }, + { + "verb": "aggregate_override", + "args": { + "groupby": [*chunk_by_columns] if len(chunk_by_columns) > 0 else None, + "aggregations": [ + { + "column": "text_with_ids", + "operation": "array_agg", + "to": "texts", + } + ], + }, + }, + { + "verb": "chunk", + "args": {"column": "texts", "to": "chunks", **config.get("text_chunk", {})}, + }, + { + "verb": "select", + "args": { + "columns": [*chunk_by_columns, "chunks"], + }, + }, + { + "verb": "unroll", + "args": { + "column": "chunks", + }, + }, + { + "verb": "rename", + "args": { + "columns": { + "chunks": chunk_column_name, + } + }, + }, + { + "verb": "genid", + "args": { + # Generate a unique id for each chunk + "to": "chunk_id", + "method": "md5_hash", + "hash": [chunk_column_name], + }, + }, + { + "verb": "unzip", + "args": { + "column": chunk_column_name, + "to": ["document_ids", chunk_column_name, n_tokens_column_name], + }, + }, + {"verb": "copy", "args": {"column": "chunk_id", "to": "id"}}, + { + # ELIMINATE EMPTY CHUNKS + "verb": "filter", + "args": { + "column": chunk_column_name, + "criteria": [ + { + "type": "value", + "operator": "is not empty", + } + ], + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_communities.py b/func-app/graphrag/index/workflows/v1/create_final_communities.py new file mode 100644 index 0000000000..f8949dfcec --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_communities.py @@ -0,0 +1,172 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_communities" + + +def build_steps( + _config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final communities table. 
+ + ## Dependencies + * `workflow:create_base_entity_graph` + """ + return [ + { + "id": "graph_nodes", + "verb": "unpack_graph", + "args": { + "column": "clustered_graph", + "type": "nodes", + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, + { + "id": "graph_edges", + "verb": "unpack_graph", + "args": { + "column": "clustered_graph", + "type": "edges", + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, + { + "id": "source_clusters", + "verb": "join", + "args": { + "on": ["label", "source"], + }, + "input": {"source": "graph_nodes", "others": ["graph_edges"]}, + }, + { + "id": "target_clusters", + "verb": "join", + "args": { + "on": ["label", "target"], + }, + "input": {"source": "graph_nodes", "others": ["graph_edges"]}, + }, + { + "id": "concatenated_clusters", + "verb": "concat", + "input": { + "source": "source_clusters", + "others": ["target_clusters"], + }, + }, + { + "id": "combined_clusters", + "verb": "filter", + "args": { + # level_1 is the left side of the join + # level_2 is the right side of the join + "column": "level_1", + "criteria": [ + {"type": "column", "operator": "equals", "value": "level_2"} + ], + }, + "input": {"source": "concatenated_clusters"}, + }, + { + "id": "cluster_relationships", + "verb": "aggregate_override", + "args": { + "groupby": [ + "cluster", + "level_1", # level_1 is the left side of the join + ], + "aggregations": [ + { + "column": "id_2", # this is the id of the edge from the join steps above + "to": "relationship_ids", + "operation": "array_agg_distinct", + }, + { + "column": "source_id_1", + "to": "text_unit_ids", + "operation": "array_agg_distinct", + }, + ], + }, + "input": {"source": "combined_clusters"}, + }, + { + "id": "all_clusters", + "verb": "aggregate_override", + "args": { + "groupby": ["cluster", "level"], + "aggregations": [{"column": "cluster", "to": "id", "operation": "any"}], + }, + "input": {"source": "graph_nodes"}, + }, + { + "verb": "join", + "args": { + "on": ["id", "cluster"], + }, + "input": {"source": "all_clusters", "others": ["cluster_relationships"]}, + }, + { + "verb": "filter", + "args": { + # level is the left side of the join + # level_1 is the right side of the join + "column": "level", + "criteria": [ + {"type": "column", "operator": "equals", "value": "level_1"} + ], + }, + }, + *create_community_title_wf, + { + # TODO: Rodrigo says "raw_community" is temporary + "verb": "copy", + "args": { + "column": "id", + "to": "raw_community", + }, + }, + { + "verb": "select", + "args": { + "columns": [ + "id", + "title", + "level", + "raw_community", + "relationship_ids", + "text_unit_ids", + ], + }, + }, + ] + + +create_community_title_wf = [ + # Hack to string concat "Community " + id + { + "verb": "fill", + "args": { + "to": "__temp", + "value": "Community ", + }, + }, + { + "verb": "merge", + "args": { + "columns": [ + "__temp", + "id", + ], + "to": "title", + "strategy": "concat", + "preserveSource": True, + }, + }, +] diff --git a/func-app/graphrag/index/workflows/v1/create_final_community_reports.py b/func-app/graphrag/index/workflows/v1/create_final_community_reports.py new file mode 100644 index 0000000000..164c70e0dd --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_community_reports.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024 Microsoft Corporation. 
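The `fill` + `merge` pair in `create_community_title_wf` is, as the comment says, a workaround for string concatenation. In plain pandas the same "Community " + id title is a one-liner (illustrative sketch only, not the verb implementation):

```python
import pandas as pd

communities = pd.DataFrame({"id": [0, 1, 2], "level": [0, 0, 1]})

# Equivalent of the fill/merge hack: prefix each community id with a label
communities["title"] = "Community " + communities["id"].astype(str)

print(communities)  # titles: "Community 0", "Community 1", "Community 2"
```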
+# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_community_reports" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final community reports table. + + ## Dependencies + * `workflow:create_base_entity_graph` + """ + covariates_enabled = config.get("covariates_enabled", False) + create_community_reports_config = config.get("create_community_reports", {}) + base_text_embed = config.get("text_embed", {}) + community_report_full_content_embed_config = config.get( + "community_report_full_content_embed", base_text_embed + ) + community_report_summary_embed_config = config.get( + "community_report_summary_embed", base_text_embed + ) + community_report_title_embed_config = config.get( + "community_report_title_embed", base_text_embed + ) + skip_title_embedding = config.get("skip_title_embedding", False) + skip_summary_embedding = config.get("skip_summary_embedding", False) + skip_full_content_embedding = config.get("skip_full_content_embedding", False) + + return [ + # + # Subworkflow: Prepare Nodes + # + { + "id": "nodes", + "verb": "prepare_community_reports_nodes", + "input": {"source": "workflow:create_final_nodes"}, + }, + # + # Subworkflow: Prepare Edges + # + { + "id": "edges", + "verb": "prepare_community_reports_edges", + "input": {"source": "workflow:create_final_relationships"}, + }, + # + # Subworkflow: Prepare Claims Table + # + { + "id": "claims", + "enabled": covariates_enabled, + "verb": "prepare_community_reports_claims", + "input": { + "source": "workflow:create_final_covariates", + } + if covariates_enabled + else {}, + }, + # + # Subworkflow: Get Community Hierarchy + # + { + "id": "community_hierarchy", + "verb": "restore_community_hierarchy", + "input": {"source": "nodes"}, + }, + # + # Main Workflow: Create Community Reports + # + { + "id": "local_contexts", + "verb": "prepare_community_reports", + "input": { + "source": "nodes", + "nodes": "nodes", + "edges": "edges", + **({"claims": "claims"} if covariates_enabled else {}), + }, + }, + { + "verb": "create_community_reports", + "args": { + **create_community_reports_config, + }, + "input": { + "source": "local_contexts", + "community_hierarchy": "community_hierarchy", + "nodes": "nodes", + }, + }, + { + # Generate a unique ID for each community report distinct from the community ID + "verb": "window", + "args": {"to": "id", "operation": "uuid", "column": "community"}, + }, + { + "verb": "text_embed", + "enabled": not skip_full_content_embedding, + "args": { + "embedding_name": "community_report_full_content", + "column": "full_content", + "to": "full_content_embedding", + **community_report_full_content_embed_config, + }, + }, + { + "verb": "text_embed", + "enabled": not skip_summary_embedding, + "args": { + "embedding_name": "community_report_summary", + "column": "summary", + "to": "summary_embedding", + **community_report_summary_embed_config, + }, + }, + { + "verb": "text_embed", + "enabled": not skip_title_embedding, + "args": { + "embedding_name": "community_report_title", + "column": "title", + "to": "title_embedding", + **community_report_title_embed_config, + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_covariates.py b/func-app/graphrag/index/workflows/v1/create_final_covariates.py new file mode 100644 index 0000000000..d1090e5054 --- /dev/null +++ 
b/func-app/graphrag/index/workflows/v1/create_final_covariates.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from datashaper import AsyncType + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_covariates" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final covariates table. + + ## Dependencies + * `workflow:create_base_text_units` + * `workflow:create_base_extracted_entities` + """ + claim_extract_config = config.get("claim_extract", {}) + + input = {"source": "workflow:create_base_text_units"} + + return [ + { + "verb": "extract_covariates", + "args": { + "column": config.get("chunk_column", "chunk"), + "id_column": config.get("chunk_id_column", "chunk_id"), + "resolved_entities_column": "resolved_entities", + "covariate_type": "claim", + "async_mode": config.get("async_mode", AsyncType.AsyncIO), + **claim_extract_config, + }, + "input": input, + }, + { + "verb": "window", + "args": {"to": "id", "operation": "uuid", "column": "covariate_type"}, + }, + { + "verb": "genid", + "args": { + "to": "human_readable_id", + "method": "increment", + }, + }, + { + "verb": "convert", + "args": { + "column": "human_readable_id", + "type": "string", + "to": "human_readable_id", + }, + }, + { + "verb": "rename", + "args": { + "columns": { + "chunk_id": "text_unit_id", + } + }, + }, + { + "verb": "select", + "args": { + "columns": [ + "id", + "human_readable_id", + "covariate_type", + "type", + "description", + "subject_id", + "subject_type", + "object_id", + "object_type", + "status", + "start_date", + "end_date", + "source_text", + "text_unit_id", + "document_ids", + "n_tokens", + ] + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_documents.py b/func-app/graphrag/index/workflows/v1/create_final_documents.py new file mode 100644 index 0000000000..d09ce001b0 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_documents.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_documents" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final documents table. 
+ + ## Dependencies + * `workflow:create_base_documents` + * `workflow:create_base_document_nodes` + """ + base_text_embed = config.get("text_embed", {}) + document_raw_content_embed_config = config.get( + "document_raw_content_embed", base_text_embed + ) + skip_raw_content_embedding = config.get("skip_raw_content_embedding", False) + return [ + { + "verb": "rename", + "args": {"columns": {"text_units": "text_unit_ids"}}, + "input": {"source": "workflow:create_base_documents"}, + }, + { + "verb": "text_embed", + "enabled": not skip_raw_content_embedding, + "args": { + "column": "raw_content", + "to": "raw_content_embedding", + **document_raw_content_embed_config, + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_entities.py b/func-app/graphrag/index/workflows/v1/create_final_entities.py new file mode 100644 index 0000000000..9d8b962b77 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_entities.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_entities" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final entities table. + + ## Dependencies + * `workflow:create_base_entity_graph` + """ + base_text_embed = config.get("text_embed", {}) + entity_name_embed_config = config.get("entity_name_embed", base_text_embed) + entity_name_description_embed_config = config.get( + "entity_name_description_embed", base_text_embed + ) + skip_name_embedding = config.get("skip_name_embedding", False) + skip_description_embedding = config.get("skip_description_embedding", False) + is_using_vector_store = ( + entity_name_embed_config.get("strategy", {}).get("vector_store", None) + is not None + ) + + return [ + { + "verb": "unpack_graph", + "args": { + "column": "clustered_graph", + "type": "nodes", + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, + {"verb": "rename", "args": {"columns": {"label": "title"}}}, + { + "verb": "select", + "args": { + "columns": [ + "id", + "title", + "type", + "description", + "human_readable_id", + "graph_embedding", + "source_id", + ], + }, + }, + { + # create_base_entity_graph has multiple levels of clustering, which means there are multiple graphs with the same entities + # this dedupes the entities so that there is only one of each entity + "verb": "dedupe", + "args": {"columns": ["id"]}, + }, + {"verb": "rename", "args": {"columns": {"title": "name"}}}, + { + # ELIMINATE EMPTY NAMES + "verb": "filter", + "args": { + "column": "name", + "criteria": [ + { + "type": "value", + "operator": "is not empty", + } + ], + }, + }, + { + "verb": "text_split", + "args": {"separator": ",", "column": "source_id", "to": "text_unit_ids"}, + }, + {"verb": "drop", "args": {"columns": ["source_id"]}}, + { + "verb": "text_embed", + "enabled": not skip_name_embedding, + "args": { + "embedding_name": "entity_name", + "column": "name", + "to": "name_embedding", + **entity_name_embed_config, + }, + }, + { + "verb": "merge", + "enabled": not skip_description_embedding, + "args": { + "strategy": "concat", + "columns": ["name", "description"], + "to": "name_description", + "delimiter": ":", + "preserveSource": True, + }, + }, + { + "verb": "text_embed", + "enabled": not skip_description_embedding, + "args": { + "embedding_name": "entity_name_description", + 
"column": "name_description", + "to": "description_embedding", + **entity_name_description_embed_config, + }, + }, + { + "verb": "drop", + "enabled": not skip_description_embedding, + "args": { + "columns": ["name_description"], + }, + }, + { + # ELIMINATE EMPTY DESCRIPTION EMBEDDINGS + "verb": "filter", + "enabled": not skip_description_embedding and not is_using_vector_store, + "args": { + "column": "description_embedding", + "criteria": [ + { + "type": "value", + "operator": "is not empty", + } + ], + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_nodes.py b/func-app/graphrag/index/workflows/v1/create_final_nodes.py new file mode 100644 index 0000000000..31277e7bf0 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_nodes.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_nodes" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the base table for the document graph. + + ## Dependencies + * `workflow:create_base_entity_graph` + """ + snapshot_top_level_nodes = config.get("snapshot_top_level_nodes", False) + layout_graph_enabled = config.get("layout_graph_enabled", True) + _compute_top_level_node_positions = [ + { + "verb": "unpack_graph", + "args": {"column": "positioned_graph", "type": "nodes"}, + "input": {"source": "laid_out_entity_graph"}, + }, + { + "verb": "filter", + "args": { + "column": "level", + "criteria": [ + { + "type": "value", + "operator": "equals", + "value": config.get("level_for_node_positions", 0), + } + ], + }, + }, + { + "verb": "select", + "args": {"columns": ["id", "x", "y"]}, + }, + { + "verb": "snapshot", + "enabled": snapshot_top_level_nodes, + "args": { + "name": "top_level_nodes", + "formats": ["json"], + }, + }, + { + "id": "_compute_top_level_node_positions", + "verb": "rename", + "args": { + "columns": { + "id": "top_level_node_id", + } + }, + }, + { + "verb": "convert", + "args": { + "column": "top_level_node_id", + "to": "top_level_node_id", + "type": "string", + }, + }, + ] + layout_graph_config = config.get( + "layout_graph", + { + "strategy": { + "type": "umap" if layout_graph_enabled else "zero", + }, + }, + ) + return [ + { + "id": "laid_out_entity_graph", + "verb": "layout_graph", + "args": { + "embeddings_column": "embeddings", + "graph_column": "clustered_graph", + "to": "node_positions", + "graph_to": "positioned_graph", + **layout_graph_config, + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, + { + "verb": "unpack_graph", + "args": {"column": "positioned_graph", "type": "nodes"}, + }, + { + "id": "nodes_without_positions", + "verb": "drop", + "args": {"columns": ["x", "y"]}, + }, + *_compute_top_level_node_positions, + { + "verb": "join", + "args": { + "on": ["id", "top_level_node_id"], + }, + "input": { + "source": "nodes_without_positions", + "others": ["_compute_top_level_node_positions"], + }, + }, + { + "verb": "rename", + "args": {"columns": {"label": "title", "cluster": "community"}}, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_relationships.py b/func-app/graphrag/index/workflows/v1/create_final_relationships.py new file mode 100644 index 0000000000..a58c2a45b4 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_relationships.py @@ -0,0 +1,94 @@ +# 
Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_relationships" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final relationships table. + + ## Dependencies + * `workflow:create_base_entity_graph` + """ + base_text_embed = config.get("text_embed", {}) + relationship_description_embed_config = config.get( + "relationship_description_embed", base_text_embed + ) + skip_description_embedding = config.get("skip_description_embedding", False) + + return [ + { + "verb": "unpack_graph", + "args": { + "column": "clustered_graph", + "type": "edges", + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, + { + "verb": "rename", + "args": {"columns": {"source_id": "text_unit_ids"}}, + }, + { + "verb": "filter", + "args": { + "column": "level", + "criteria": [{"type": "value", "operator": "equals", "value": 0}], + }, + }, + { + "verb": "text_embed", + "enabled": not skip_description_embedding, + "args": { + "embedding_name": "relationship_description", + "column": "description", + "to": "description_embedding", + **relationship_description_embed_config, + }, + }, + { + "id": "pruned_edges", + "verb": "drop", + "args": {"columns": ["level"]}, + }, + { + "id": "filtered_nodes", + "verb": "filter", + "args": { + "column": "level", + "criteria": [{"type": "value", "operator": "equals", "value": 0}], + }, + "input": "workflow:create_final_nodes", + }, + { + "verb": "compute_edge_combined_degree", + "args": {"to": "rank"}, + "input": { + "source": "pruned_edges", + "nodes": "filtered_nodes", + }, + }, + { + "verb": "convert", + "args": { + "column": "human_readable_id", + "type": "string", + "to": "human_readable_id", + }, + }, + { + "verb": "convert", + "args": { + "column": "text_unit_ids", + "type": "array", + "to": "text_unit_ids", + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_text_units.py b/func-app/graphrag/index/workflows/v1/create_final_text_units.py new file mode 100644 index 0000000000..56dd0a73d6 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_text_units.py @@ -0,0 +1,161 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_text_units" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final text-units table. 
+ + ## Dependencies + * `workflow:create_base_text_units` + * `workflow:create_final_entities` + * `workflow:create_final_communities` + """ + base_text_embed = config.get("text_embed", {}) + text_unit_text_embed_config = config.get("text_unit_text_embed", base_text_embed) + covariates_enabled = config.get("covariates_enabled", False) + skip_text_unit_embedding = config.get("skip_text_unit_embedding", False) + is_using_vector_store = ( + text_unit_text_embed_config.get("strategy", {}).get("vector_store", None) + is not None + ) + + return [ + { + "verb": "select", + "args": {"columns": ["id", "chunk", "document_ids", "n_tokens"]}, + "input": {"source": "workflow:create_base_text_units"}, + }, + { + "id": "pre_entity_join", + "verb": "rename", + "args": { + "columns": { + "chunk": "text", + }, + }, + }, + # Expand the TextUnits with EntityIDs + { + "id": "pre_relationship_join", + "verb": "join", + "args": { + "on": ["id", "id"], + "strategy": "left outer", + }, + "input": { + "source": "pre_entity_join", + "others": ["workflow:join_text_units_to_entity_ids"], + }, + }, + # Expand the TextUnits with RelationshipIDs + { + "id": "pre_covariate_join", + "verb": "join", + "args": { + "on": ["id", "id"], + "strategy": "left outer", + }, + "input": { + "source": "pre_relationship_join", + "others": ["workflow:join_text_units_to_relationship_ids"], + }, + }, + # Expand the TextUnits with CovariateIDs + { + "enabled": covariates_enabled, + "verb": "join", + "args": { + "on": ["id", "id"], + "strategy": "left outer", + }, + "input": { + "source": "pre_covariate_join", + "others": ["workflow:join_text_units_to_covariate_ids"], + }, + }, + # Mash the entities and relationships into arrays + { + "verb": "aggregate_override", + "args": { + "groupby": ["id"], # from the join above + "aggregations": [ + { + "column": "text", + "operation": "any", + "to": "text", + }, + { + "column": "n_tokens", + "operation": "any", + "to": "n_tokens", + }, + { + "column": "document_ids", + "operation": "any", + "to": "document_ids", + }, + { + "column": "entity_ids", + "operation": "any", + "to": "entity_ids", + }, + { + "column": "relationship_ids", + "operation": "any", + "to": "relationship_ids", + }, + *( + [] + if not covariates_enabled + else [ + { + "column": "covariate_ids", + "operation": "any", + "to": "covariate_ids", + } + ] + ), + ], + }, + }, + # Text-Embed after final aggregations + { + "id": "embedded_text_units", + "verb": "text_embed", + "enabled": not skip_text_unit_embedding, + "args": { + "column": config.get("column", "text"), + "to": config.get("to", "text_embedding"), + **text_unit_text_embed_config, + }, + }, + { + "verb": "select", + "args": { + # Final select to get output in the correct shape + "columns": [ + "id", + "text", + *( + [] + if (skip_text_unit_embedding or is_using_vector_store) + else ["text_embedding"] + ), + "n_tokens", + "document_ids", + "entity_ids", + "relationship_ids", + *([] if not covariates_enabled else ["covariate_ids"]), + ], + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_summarized_entities.py b/func-app/graphrag/index/workflows/v1/create_summarized_entities.py new file mode 100644 index 0000000000..8f8d7f0042 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_summarized_entities.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024 Microsoft Corporation. 
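The shape of this workflow is: start from the text units, left-join each id-mapping table, then collapse back to one row per text unit with "any" aggregations. A rough pandas analogue of one join-plus-collapse step (column names follow the workflow above; the merge/groupby is a sketch, not the datashaper engine):

```python
import pandas as pd

text_units = pd.DataFrame({"id": ["t1", "t2"], "text": ["abc", "def"]})
unit_to_entities = pd.DataFrame({"id": ["t1"], "entity_ids": [["e1", "e2"]]})

# "left outer" join keeps text units that matched no entities (t2 here)
joined = text_units.merge(unit_to_entities, on="id", how="left")

# aggregate_override with operation "any" ~ take one value per group
final = joined.groupby("id", as_index=False).agg(
    text=("text", "first"),
    entity_ids=("entity_ids", "first"),
)
print(final)  # t1 carries its entity ids; t2's entity_ids stays empty
```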
+# Licensed under the MIT License
+
+"""A module containing build_steps method definition."""
+
+from datashaper import AsyncType
+
+from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
+
+workflow_name = "create_summarized_entities"
+
+
+def build_steps(
+    config: PipelineWorkflowConfig,
+) -> list[PipelineWorkflowStep]:
+    """
+    Create the summarized entities table.
+
+    ## Dependencies
+    * `workflow:create_base_extracted_entities`
+    """
+    summarize_descriptions_config = config.get("summarize_descriptions", {})
+    graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
+
+    return [
+        {
+            "verb": "summarize_descriptions",
+            "args": {
+                **summarize_descriptions_config,
+                "column": "entity_graph",
+                "to": "entity_graph",
+                "async_mode": summarize_descriptions_config.get(
+                    "async_mode", AsyncType.AsyncIO
+                ),
+            },
+            "input": {"source": "workflow:create_base_extracted_entities"},
+        },
+        {
+            "verb": "snapshot_rows",
+            "enabled": graphml_snapshot_enabled,
+            "args": {
+                "base_name": "summarized_graph",
+                "column": "entity_graph",
+                "formats": [{"format": "text", "extension": "graphml"}],
+            },
+        },
+    ]
diff --git a/func-app/graphrag/index/workflows/v1/join_text_units_to_covariate_ids.py b/func-app/graphrag/index/workflows/v1/join_text_units_to_covariate_ids.py
new file mode 100644
index 0000000000..be6bddf1e4
--- /dev/null
+++ b/func-app/graphrag/index/workflows/v1/join_text_units_to_covariate_ids.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing build_steps method definition."""
+
+from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
+
+workflow_name = "join_text_units_to_covariate_ids"
+
+
+def build_steps(
+    _config: PipelineWorkflowConfig,
+) -> list[PipelineWorkflowStep]:
+    """
+    Create a join table from text unit ids to covariate ids.
+
+    ## Dependencies
+    * `workflow:create_final_covariates`
+    """
+    return [
+        {
+            "verb": "select",
+            "args": {"columns": ["id", "text_unit_id"]},
+            "input": {"source": "workflow:create_final_covariates"},
+        },
+        {
+            "verb": "aggregate_override",
+            "args": {
+                "groupby": ["text_unit_id"],
+                "aggregations": [
+                    {
+                        "column": "id",
+                        "operation": "array_agg_distinct",
+                        "to": "covariate_ids",
+                    },
+                    {
+                        "column": "text_unit_id",
+                        "operation": "any",
+                        "to": "id",
+                    },
+                ],
+            },
+        },
+    ]
diff --git a/func-app/graphrag/index/workflows/v1/join_text_units_to_entity_ids.py b/func-app/graphrag/index/workflows/v1/join_text_units_to_entity_ids.py
new file mode 100644
index 0000000000..6337502d1a
--- /dev/null
+++ b/func-app/graphrag/index/workflows/v1/join_text_units_to_entity_ids.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing build_steps method definition."""
+
+from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
+
+workflow_name = "join_text_units_to_entity_ids"
+
+
+def build_steps(
+    _config: PipelineWorkflowConfig,
+) -> list[PipelineWorkflowStep]:
+    """
+    Create a join table from text unit ids to entity ids.
+
+    ## Dependencies
+    * `workflow:create_final_entities`
+    """
+    return [
+        {
+            "verb": "select",
+            "args": {"columns": ["id", "text_unit_ids"]},
+            "input": {"source": "workflow:create_final_entities"},
+        },
+        {
+            "verb": "unroll",
+            "args": {
+                "column": "text_unit_ids",
+            },
+        },
+        {
+            "verb": "aggregate_override",
+            "args": {
+                "groupby": ["text_unit_ids"],
+                "aggregations": [
+                    {
+                        "column": "id",
+                        "operation": "array_agg_distinct",
+                        "to": "entity_ids",
+                    },
+                    {
+                        "column": "text_unit_ids",
+                        "operation": "any",
+                        "to": "id",
+                    },
+                ],
+            },
+        },
+    ]
diff --git a/func-app/graphrag/index/workflows/v1/join_text_units_to_relationship_ids.py b/func-app/graphrag/index/workflows/v1/join_text_units_to_relationship_ids.py
new file mode 100644
index 0000000000..fe6d6463be
--- /dev/null
+++ b/func-app/graphrag/index/workflows/v1/join_text_units_to_relationship_ids.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing build_steps method definition."""
+
+from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
+
+workflow_name = "join_text_units_to_relationship_ids"
+
+
+def build_steps(
+    _config: PipelineWorkflowConfig,
+) -> list[PipelineWorkflowStep]:
+    """
+    Create a join table from text unit ids to relationship ids.
+
+    ## Dependencies
+    * `workflow:create_final_relationships`
+    """
+    return [
+        {
+            "verb": "select",
+            "args": {"columns": ["id", "text_unit_ids"]},
+            "input": {"source": "workflow:create_final_relationships"},
+        },
+        {
+            "verb": "unroll",
+            "args": {
+                "column": "text_unit_ids",
+            },
+        },
+        {
+            "verb": "aggregate_override",
+            "args": {
+                "groupby": ["text_unit_ids"],
+                "aggregations": [
+                    {
+                        "column": "id",
+                        "operation": "array_agg_distinct",
+                        "to": "relationship_ids",
+                    },
+                    {
+                        "column": "text_unit_ids",
+                        "operation": "any",
+                        "to": "id",
+                    },
+                ],
+            },
+        },
+        {
+            "id": "text_unit_id_to_relationship_ids",
+            "verb": "select",
+            "args": {"columns": ["id", "relationship_ids"]},
+        },
+    ]
diff --git a/func-app/graphrag/llm/__init__.py b/func-app/graphrag/llm/__init__.py
new file mode 100644
index 0000000000..609be951b2
--- /dev/null
+++ b/func-app/graphrag/llm/__init__.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2024 Microsoft Corporation.
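All three join_text_units_to_* workflows invert a one-to-many mapping the same way: unroll the per-row id list, then group by text unit and re-aggregate distinct ids. A pandas sketch of those two key steps (sample data invented; `explode`/`groupby` stand in for the `unroll` and `aggregate_override` verbs):

```python
import pandas as pd

# One row per relationship, each carrying the text units it came from
rels = pd.DataFrame({
    "id": ["r1", "r2"],
    "text_unit_ids": [["t1", "t2"], ["t1"]],
})

# unroll: one row per (relationship, text unit) pair
exploded = rels.explode("text_unit_ids")

# aggregate: distinct relationship ids per text unit (array_agg_distinct)
per_unit = (
    exploded.groupby("text_unit_ids")["id"]
    .agg(lambda ids: sorted(set(ids)))
    .reset_index()
    .rename(columns={"text_unit_ids": "id", "id": "relationship_ids"})
)
print(per_unit)  # t1 -> [r1, r2], t2 -> [r1]
```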
+# Licensed under the MIT License + +"""The Datashaper OpenAI Utilities package.""" + +from .base import BaseLLM, CachingLLM, RateLimitingLLM +from .errors import RetriesExhaustedError +from .limiting import ( + CompositeLLMLimiter, + LLMLimiter, + NoopLLMLimiter, + TpmRpmLLMLimiter, + create_tpm_rpm_limiters, +) +from .mock import MockChatLLM, MockCompletionLLM +from .openai import ( + OpenAIChatLLM, + OpenAIClientTypes, + OpenAICompletionLLM, + OpenAIConfiguration, + OpenAIEmbeddingsLLM, + create_openai_chat_llm, + create_openai_client, + create_openai_completion_llm, + create_openai_embedding_llm, +) +from .types import ( + LLM, + CompletionInput, + CompletionLLM, + CompletionOutput, + EmbeddingInput, + EmbeddingLLM, + EmbeddingOutput, + ErrorHandlerFn, + IsResponseValidFn, + LLMCache, + LLMConfig, + LLMInput, + LLMInvocationFn, + LLMInvocationResult, + LLMOutput, + OnCacheActionFn, +) + +__all__ = [ + # LLM Types + "LLM", + "BaseLLM", + "CachingLLM", + "CompletionInput", + "CompletionLLM", + "CompletionOutput", + "CompositeLLMLimiter", + "EmbeddingInput", + "EmbeddingLLM", + "EmbeddingOutput", + # Callbacks + "ErrorHandlerFn", + "IsResponseValidFn", + # Cache + "LLMCache", + "LLMConfig", + # LLM I/O Types + "LLMInput", + "LLMInvocationFn", + "LLMInvocationResult", + "LLMLimiter", + "LLMOutput", + "MockChatLLM", + # Mock + "MockCompletionLLM", + "NoopLLMLimiter", + "OnCacheActionFn", + "OpenAIChatLLM", + "OpenAIClientTypes", + "OpenAICompletionLLM", + # OpenAI + "OpenAIConfiguration", + "OpenAIEmbeddingsLLM", + "RateLimitingLLM", + # Errors + "RetriesExhaustedError", + "TpmRpmLLMLimiter", + "create_openai_chat_llm", + "create_openai_client", + "create_openai_completion_llm", + "create_openai_embedding_llm", + # Limiters + "create_tpm_rpm_limiters", +] diff --git a/func-app/graphrag/llm/base/__init__.py b/func-app/graphrag/llm/base/__init__.py new file mode 100644 index 0000000000..dd5ebf9050 --- /dev/null +++ b/func-app/graphrag/llm/base/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Base LLM Implementations.""" + +from .base_llm import BaseLLM +from .caching_llm import CachingLLM +from .rate_limiting_llm import RateLimitingLLM + +__all__ = ["BaseLLM", "CachingLLM", "RateLimitingLLM"] diff --git a/func-app/graphrag/llm/base/_create_cache_key.py b/func-app/graphrag/llm/base/_create_cache_key.py new file mode 100644 index 0000000000..b5fdd839bc --- /dev/null +++ b/func-app/graphrag/llm/base/_create_cache_key.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Cache key generation utils.""" + +import hashlib +import json + + +def _llm_string(params: dict) -> str: + # New version of the cache is not including n in the params dictionary + # This avoids creating a new cache key for the same prompt + if "max_tokens" in params and "n" not in params: + params["n"] = None + return str(sorted((k, v) for k, v in params.items())) + + +def _hash(_input: str) -> str: + """Use a deterministic hashing approach.""" + return hashlib.md5(_input.encode()).hexdigest() # noqa S324 + + +def create_hash_key( + operation: str, prompt: str, parameters: dict, history: list[dict] | None +) -> str: + """Compute cache key from prompt and associated model and settings. + + Args: + prompt (str): The prompt run through the language model. + llm_string (str): The language model version and settings. + + Returns + ------- + str: The cache key. 
+ """ + llm_string = _llm_string(parameters) + history_string = _hash(json.dumps(history)) if history else None + hash_string = ( + _hash(prompt + llm_string + history_string) + if history_string + else _hash(prompt + llm_string) + ) + return f"{operation}-{hash_string}" diff --git a/func-app/graphrag/llm/base/base_llm.py b/func-app/graphrag/llm/base/base_llm.py new file mode 100644 index 0000000000..66f1919cd4 --- /dev/null +++ b/func-app/graphrag/llm/base/base_llm.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Base LLM class definition.""" + +import traceback +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +from typing_extensions import Unpack + +from graphrag.llm.types import ( + LLM, + ErrorHandlerFn, + LLMInput, + LLMOutput, +) + +TIn = TypeVar("TIn") +TOut = TypeVar("TOut") + + +class BaseLLM(ABC, LLM[TIn, TOut], Generic[TIn, TOut]): + """LLM Implementation class definition.""" + + _on_error: ErrorHandlerFn | None + + def on_error(self, on_error: ErrorHandlerFn | None) -> None: + """Set the error handler function.""" + self._on_error = on_error + + @abstractmethod + async def _execute_llm( + self, + input: TIn, + **kwargs: Unpack[LLMInput], + ) -> TOut | None: + pass + + async def __call__( + self, + input: TIn, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[TOut]: + """Invoke the LLM.""" + is_json = kwargs.get("json") or False + if is_json: + return await self._invoke_json(input, **kwargs) + return await self._invoke(input, **kwargs) + + async def _invoke(self, input: TIn, **kwargs: Unpack[LLMInput]) -> LLMOutput[TOut]: + try: + output = await self._execute_llm(input, **kwargs) + return LLMOutput(output=output) + except Exception as e: + stack_trace = traceback.format_exc() + if self._on_error: + self._on_error(e, stack_trace, {"input": input}) + raise + + async def _invoke_json( + self, input: TIn, **kwargs: Unpack[LLMInput] + ) -> LLMOutput[TOut]: + msg = "JSON output not supported by this LLM" + raise NotImplementedError(msg) diff --git a/func-app/graphrag/llm/base/caching_llm.py b/func-app/graphrag/llm/base/caching_llm.py new file mode 100644 index 0000000000..c039de5122 --- /dev/null +++ b/func-app/graphrag/llm/base/caching_llm.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024 Microsoft Corporation. 
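The BaseLLM contract above is small: subclasses implement only _execute_llm, while __call__ handles invocation, JSON dispatch, and error reporting. A minimal sketch; EchoLLM is a hypothetical example, not part of this patch.

import asyncio

from graphrag.llm.base import BaseLLM


class EchoLLM(BaseLLM[str, str]):
    """Echo the input back; stands in for a real model call."""

    def __init__(self) -> None:
        self._on_error = None

    async def _execute_llm(self, input: str, **kwargs) -> str:
        return input


async def main() -> None:
    llm = EchoLLM()
    # ErrorHandlerFn receives the exception, traceback text, and call context.
    llm.on_error(lambda e, stack, details: print("llm failed:", e))
    result = await llm("hello")
    assert result.output == "hello"


asyncio.run(main())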
+# Licensed under the MIT License + +"""A class to interact with the cache.""" + +import json +from typing import Generic, TypeVar + +from typing_extensions import Unpack + +from graphrag.llm.types import LLM, LLMCache, LLMInput, LLMOutput, OnCacheActionFn + +from ._create_cache_key import create_hash_key + +# If there's a breaking change in what we cache, we should increment this version number to invalidate existing caches +_cache_strategy_version = 2 + +TIn = TypeVar("TIn") +TOut = TypeVar("TOut") + + +def _noop_cache_fn(_k: str, _v: str | None): + pass + + +class CachingLLM(LLM[TIn, TOut], Generic[TIn, TOut]): + """A class to interact with the cache.""" + + _cache: LLMCache + _delegate: LLM[TIn, TOut] + _operation: str + _llm_parameters: dict + _on_cache_hit: OnCacheActionFn + _on_cache_miss: OnCacheActionFn + + def __init__( + self, + delegate: LLM[TIn, TOut], + llm_parameters: dict, + operation: str, + cache: LLMCache, + ): + self._delegate = delegate + self._llm_parameters = llm_parameters + self._cache = cache + self._operation = operation + self._on_cache_hit = _noop_cache_fn + self._on_cache_miss = _noop_cache_fn + + def set_delegate(self, delegate: LLM[TIn, TOut]) -> None: + """Set the delegate LLM. (for testing).""" + self._delegate = delegate + + def on_cache_hit(self, fn: OnCacheActionFn | None) -> None: + """Set the function to call when a cache hit occurs.""" + self._on_cache_hit = fn or _noop_cache_fn + + def on_cache_miss(self, fn: OnCacheActionFn | None) -> None: + """Set the function to call when a cache miss occurs.""" + self._on_cache_miss = fn or _noop_cache_fn + + def _cache_key( + self, input: TIn, name: str | None, args: dict, history: list[dict] | None + ) -> str: + json_input = json.dumps(input) + tag = ( + f"{name}-{self._operation}-v{_cache_strategy_version}" + if name is not None + else self._operation + ) + return create_hash_key(tag, json_input, args, history) + + async def __call__( + self, + input: TIn, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[TOut]: + """Execute the LLM.""" + # Check for an Existing cache item + name = kwargs.get("name") + history_in = kwargs.get("history") or None + llm_args = {**self._llm_parameters, **(kwargs.get("model_parameters") or {})} + cache_key = self._cache_key(input, name, llm_args, history_in) + cached_result = await self._cache.get(cache_key) + + if cached_result: + self._on_cache_hit(cache_key, name) + return LLMOutput( + output=cached_result, + ) + + # Report the Cache Miss + self._on_cache_miss(cache_key, name) + + # Compute the new result + result = await self._delegate(input, **kwargs) + + # Cache the new result + if result.output is not None: + await self._cache.set( + cache_key, + result.output, + { + "input": input, + "parameters": llm_args, + "history": history_in, + }, + ) + return result diff --git a/func-app/graphrag/llm/base/rate_limiting_llm.py b/func-app/graphrag/llm/base/rate_limiting_llm.py new file mode 100644 index 0000000000..5e2082475f --- /dev/null +++ b/func-app/graphrag/llm/base/rate_limiting_llm.py @@ -0,0 +1,208 @@ +# Copyright (c) 2024 Microsoft Corporation. 
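A sketch of CachingLLM above in isolation, using a dict-backed cache. InMemoryCache is a hypothetical stand-in satisfying the LLMCache protocol (has/get/set); MockCompletionLLM is the mock defined later in this patch.

import asyncio
from typing import Any

from graphrag.llm.base import CachingLLM
from graphrag.llm.mock import MockCompletionLLM


class InMemoryCache:
    def __init__(self) -> None:
        self._items: dict[str, Any] = {}

    async def has(self, key: str) -> bool:
        return key in self._items

    async def get(self, key: str) -> Any | None:
        return self._items.get(key)

    async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None:
        self._items[key] = value


async def main() -> None:
    llm = CachingLLM(
        MockCompletionLLM(["static response"]),
        llm_parameters={"model": "mock", "temperature": 0},
        operation="completion",
        cache=InMemoryCache(),
    )
    llm.on_cache_miss(lambda key, name: print("miss:", key))
    llm.on_cache_hit(lambda key, name: print("hit:", key))
    await llm("same prompt")  # miss: computed by the delegate and stored
    await llm("same prompt")  # hit: served from the cache


asyncio.run(main())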
+# Licensed under the MIT License + +"""Rate limiting LLM implementation.""" + +import asyncio +import logging +from collections.abc import Callable +from typing import Any, Generic, TypeVar + +from tenacity import ( + AsyncRetrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) +from typing_extensions import Unpack + +from graphrag.llm.errors import RetriesExhaustedError +from graphrag.llm.limiting import LLMLimiter +from graphrag.llm.types import ( + LLM, + LLMConfig, + LLMInput, + LLMInvocationFn, + LLMInvocationResult, + LLMOutput, +) + +TIn = TypeVar("TIn") +TOut = TypeVar("TOut") +TRateLimitError = TypeVar("TRateLimitError", bound=BaseException) + +_CANNOT_MEASURE_INPUT_TOKENS_MSG = "cannot measure input tokens" +_CANNOT_MEASURE_OUTPUT_TOKENS_MSG = "cannot measure output tokens" + +log = logging.getLogger(__name__) + + +class RateLimitingLLM(LLM[TIn, TOut], Generic[TIn, TOut]): + """A class to interact with the cache.""" + + _delegate: LLM[TIn, TOut] + _rate_limiter: LLMLimiter | None + _semaphore: asyncio.Semaphore | None + _count_tokens: Callable[[str], int] + _config: LLMConfig + _operation: str + _retryable_errors: list[type[Exception]] + _rate_limit_errors: list[type[Exception]] + _on_invoke: LLMInvocationFn + _extract_sleep_recommendation: Callable[[Any], float] + + def __init__( + self, + delegate: LLM[TIn, TOut], + config: LLMConfig, + operation: str, + retryable_errors: list[type[Exception]], + rate_limit_errors: list[type[Exception]], + rate_limiter: LLMLimiter | None = None, + semaphore: asyncio.Semaphore | None = None, + count_tokens: Callable[[str], int] | None = None, + get_sleep_time: Callable[[BaseException], float] | None = None, + ): + self._delegate = delegate + self._rate_limiter = rate_limiter + self._semaphore = semaphore + self._config = config + self._operation = operation + self._retryable_errors = retryable_errors + self._rate_limit_errors = rate_limit_errors + self._count_tokens = count_tokens or (lambda _s: -1) + self._extract_sleep_recommendation = get_sleep_time or (lambda _e: 0.0) + self._on_invoke = lambda _v: None + + def on_invoke(self, fn: LLMInvocationFn | None) -> None: + """Set the on_invoke function.""" + self._on_invoke = fn or (lambda _v: None) + + def count_request_tokens(self, input: TIn) -> int: + """Count the request tokens on an input request.""" + if isinstance(input, str): + return self._count_tokens(input) + if isinstance(input, list): + result = 0 + for item in input: + if isinstance(item, str): + result += self._count_tokens(item) + elif isinstance(item, dict): + result += self._count_tokens(item.get("content", "")) + else: + raise TypeError(_CANNOT_MEASURE_INPUT_TOKENS_MSG) + return result + raise TypeError(_CANNOT_MEASURE_INPUT_TOKENS_MSG) + + def count_response_tokens(self, output: TOut | None) -> int: + """Count the request tokens on an output response.""" + if output is None: + return 0 + if isinstance(output, str): + return self._count_tokens(output) + if isinstance(output, list) and all(isinstance(x, str) for x in output): + return sum(self._count_tokens(item) for item in output) + if isinstance(output, list): + # Embedding response, don't count it + return 0 + raise TypeError(_CANNOT_MEASURE_OUTPUT_TOKENS_MSG) + + async def __call__( + self, + input: TIn, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[TOut]: + """Execute the LLM with semaphore & rate limiting.""" + name = kwargs.get("name", "Process") + attempt_number = 0 + call_times: list[float] = [] + input_tokens = 
self.count_request_tokens(input) + max_retries = self._config.max_retries or 10 + max_retry_wait = self._config.max_retry_wait or 10 + follow_recommendation = self._config.sleep_on_rate_limit_recommendation + retryer = AsyncRetrying( + stop=stop_after_attempt(max_retries), + wait=wait_exponential_jitter(max=max_retry_wait), + reraise=True, + retry=retry_if_exception_type(tuple(self._retryable_errors)), + ) + + async def sleep_for(time: float | None) -> None: + log.warning( + "%s failed to invoke LLM %s/%s attempts. Cause: rate limit exceeded, will retry. Recommended sleep for %d seconds. Follow recommendation? %s", + name, + attempt_number, + max_retries, + time, + follow_recommendation, + ) + if follow_recommendation and time: + await asyncio.sleep(time) + raise + + async def do_attempt() -> LLMOutput[TOut]: + nonlocal call_times + call_start = asyncio.get_event_loop().time() + try: + return await self._delegate(input, **kwargs) + except BaseException as e: + if isinstance(e, tuple(self._rate_limit_errors)): + sleep_time = self._extract_sleep_recommendation(e) + await sleep_for(sleep_time) + raise + finally: + call_end = asyncio.get_event_loop().time() + call_times.append(call_end - call_start) + + async def execute_with_retry() -> tuple[LLMOutput[TOut], float]: + nonlocal attempt_number + async for attempt in retryer: + with attempt: + if self._rate_limiter and input_tokens > 0: + await self._rate_limiter.acquire(input_tokens) + start = asyncio.get_event_loop().time() + attempt_number += 1 + return await do_attempt(), start + + log.error("Retries exhausted for %s", name) + raise RetriesExhaustedError(name, max_retries) + + result: LLMOutput[TOut] + start = 0.0 + + if self._semaphore is None: + result, start = await execute_with_retry() + else: + async with self._semaphore: + result, start = await execute_with_retry() + + end = asyncio.get_event_loop().time() + output_tokens = self.count_response_tokens(result.output) + if self._rate_limiter and output_tokens > 0: + await self._rate_limiter.acquire(output_tokens) + + invocation_result = LLMInvocationResult( + result=result, + name=name, + num_retries=attempt_number - 1, + total_time=end - start, + call_times=call_times, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + self._handle_invoke_result(invocation_result) + return result + + def _handle_invoke_result( + self, result: LLMInvocationResult[LLMOutput[TOut]] + ) -> None: + log.info( + 'perf - llm.%s "%s" with %s retries took %s. input_tokens=%d, output_tokens=%d', + self._operation, + result.name, + result.num_retries, + result.total_time, + result.input_tokens, + result.output_tokens, + ) + self._on_invoke(result) diff --git a/func-app/graphrag/llm/errors.py b/func-app/graphrag/llm/errors.py new file mode 100644 index 0000000000..01136359de --- /dev/null +++ b/func-app/graphrag/llm/errors.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Microsoft Corporation. 
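The retry loop above only needs a handful of config properties, so any structural match for the LLMConfig protocol works. A sketch with hypothetical test doubles (RetryConfig, FlakyLLM); these are not part of this patch.

import asyncio
from dataclasses import dataclass

from graphrag.llm.base import BaseLLM, RateLimitingLLM


@dataclass
class RetryConfig:
    max_retries: int | None = 3
    max_retry_wait: float | None = 1.0
    sleep_on_rate_limit_recommendation: bool | None = False
    tokens_per_minute: int | None = 0
    requests_per_minute: int | None = 0


class FlakyLLM(BaseLLM[str, str]):
    """Fails a fixed number of times, then succeeds."""

    def __init__(self, failures: int) -> None:
        self._on_error = None
        self._failures = failures

    async def _execute_llm(self, input: str, **kwargs) -> str:
        if self._failures > 0:
            self._failures -= 1
            raise TimeoutError("transient failure")
        return "ok"


async def main() -> None:
    llm = RateLimitingLLM(
        FlakyLLM(failures=2),
        config=RetryConfig(),
        operation="chat",
        retryable_errors=[TimeoutError],
        rate_limit_errors=[],
        semaphore=asyncio.Semaphore(4),  # cap concurrent in-flight calls
    )
    llm.on_invoke(lambda r: print(f"{r.name}: {r.num_retries} retries"))
    result = await llm("prompt", name="demo")
    assert result.output == "ok"  # succeeded on the third attempt


asyncio.run(main())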
+# Licensed under the MIT License + +"""Error definitions for the OpenAI DataShaper package.""" + + +class RetriesExhaustedError(RuntimeError): + """Retries exhausted error.""" + + def __init__(self, name: str, num_retries: int) -> None: + """Init method definition.""" + super().__init__(f"Operation '{name}' failed - {num_retries} retries exhausted") diff --git a/func-app/graphrag/llm/limiting/__init__.py b/func-app/graphrag/llm/limiting/__init__.py new file mode 100644 index 0000000000..4f7933d1a8 --- /dev/null +++ b/func-app/graphrag/llm/limiting/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM limiters module.""" + +from .composite_limiter import CompositeLLMLimiter +from .create_limiters import create_tpm_rpm_limiters +from .llm_limiter import LLMLimiter +from .noop_llm_limiter import NoopLLMLimiter +from .tpm_rpm_limiter import TpmRpmLLMLimiter + +__all__ = [ + "CompositeLLMLimiter", + "LLMLimiter", + "NoopLLMLimiter", + "TpmRpmLLMLimiter", + "create_tpm_rpm_limiters", +] diff --git a/func-app/graphrag/llm/limiting/composite_limiter.py b/func-app/graphrag/llm/limiting/composite_limiter.py new file mode 100644 index 0000000000..7bcf9195b2 --- /dev/null +++ b/func-app/graphrag/llm/limiting/composite_limiter.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing Composite Limiter class definition.""" + +from .llm_limiter import LLMLimiter + + +class CompositeLLMLimiter(LLMLimiter): + """Composite Limiter class definition.""" + + _limiters: list[LLMLimiter] + + def __init__(self, limiters: list[LLMLimiter]): + """Init method definition.""" + self._limiters = limiters + + @property + def needs_token_count(self) -> bool: + """Whether this limiter needs the token count to be passed in.""" + return any(limiter.needs_token_count for limiter in self._limiters) + + async def acquire(self, num_tokens: int = 1) -> None: + """Call method definition.""" + for limiter in self._limiters: + await limiter.acquire(num_tokens) diff --git a/func-app/graphrag/llm/limiting/create_limiters.py b/func-app/graphrag/llm/limiting/create_limiters.py new file mode 100644 index 0000000000..92df11c1a6 --- /dev/null +++ b/func-app/graphrag/llm/limiting/create_limiters.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Create limiters for OpenAI API requests.""" + +import logging + +from aiolimiter import AsyncLimiter + +from graphrag.llm.types import LLMConfig + +from .llm_limiter import LLMLimiter +from .tpm_rpm_limiter import TpmRpmLLMLimiter + +log = logging.getLogger(__name__) + +"""The global TPM limiters.""" + + +def create_tpm_rpm_limiters( + configuration: LLMConfig, +) -> LLMLimiter: + """Get the limiters for a given model name.""" + tpm = configuration.tokens_per_minute + rpm = configuration.requests_per_minute + return TpmRpmLLMLimiter( + None if tpm == 0 else AsyncLimiter(tpm or 50_000), + None if rpm == 0 else AsyncLimiter(rpm or 10_000), + ) diff --git a/func-app/graphrag/llm/limiting/llm_limiter.py b/func-app/graphrag/llm/limiting/llm_limiter.py new file mode 100644 index 0000000000..1264a84be5 --- /dev/null +++ b/func-app/graphrag/llm/limiting/llm_limiter.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. 
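Note how the factory above maps configuration onto aiolimiter buckets: a value of 0 disables that limiter entirely, while None falls back to the defaults (50k tokens/min, 10k requests/min). A sketch; RateConfig is a hypothetical structural stand-in for the LLMConfig protocol.

import asyncio
from dataclasses import dataclass

from graphrag.llm.limiting import create_tpm_rpm_limiters


@dataclass
class RateConfig:
    max_retries: int | None = None
    max_retry_wait: float | None = None
    sleep_on_rate_limit_recommendation: bool | None = None
    tokens_per_minute: int | None = 10_000  # TPM bucket of 10k tokens/minute
    requests_per_minute: int | None = 0     # 0 -> no RPM limiter at all


async def main() -> None:
    limiter = create_tpm_rpm_limiters(RateConfig())
    # needs_token_count is True because a TPM limiter is active, so callers
    # must pass the token count into acquire().
    assert limiter.needs_token_count
    await limiter.acquire(num_tokens=500)


asyncio.run(main())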
+# Licensed under the MIT License + +"""Limiting types.""" + +from abc import ABC, abstractmethod + + +class LLMLimiter(ABC): + """LLM Limiter Interface.""" + + @property + @abstractmethod + def needs_token_count(self) -> bool: + """Whether this limiter needs the token count to be passed in.""" + + @abstractmethod + async def acquire(self, num_tokens: int = 1) -> None: + """Acquire a pass through the limiter.""" diff --git a/func-app/graphrag/llm/limiting/noop_llm_limiter.py b/func-app/graphrag/llm/limiting/noop_llm_limiter.py new file mode 100644 index 0000000000..5147055255 --- /dev/null +++ b/func-app/graphrag/llm/limiting/noop_llm_limiter.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""TPM RPM Limiter module.""" + +from .llm_limiter import LLMLimiter + + +class NoopLLMLimiter(LLMLimiter): + """TPM RPM Limiter class definition.""" + + @property + def needs_token_count(self) -> bool: + """Whether this limiter needs the token count to be passed in.""" + return False + + async def acquire(self, num_tokens: int = 1) -> None: + """Call method definition.""" + # do nothing diff --git a/func-app/graphrag/llm/limiting/tpm_rpm_limiter.py b/func-app/graphrag/llm/limiting/tpm_rpm_limiter.py new file mode 100644 index 0000000000..cb6d84e377 --- /dev/null +++ b/func-app/graphrag/llm/limiting/tpm_rpm_limiter.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""TPM RPM Limiter module.""" + +from aiolimiter import AsyncLimiter + +from .llm_limiter import LLMLimiter + + +class TpmRpmLLMLimiter(LLMLimiter): + """TPM RPM Limiter class definition.""" + + _tpm_limiter: AsyncLimiter | None + _rpm_limiter: AsyncLimiter | None + + def __init__( + self, tpm_limiter: AsyncLimiter | None, rpm_limiter: AsyncLimiter | None + ): + """Init method definition.""" + self._tpm_limiter = tpm_limiter + self._rpm_limiter = rpm_limiter + + @property + def needs_token_count(self) -> bool: + """Whether this limiter needs the token count to be passed in.""" + return self._tpm_limiter is not None + + async def acquire(self, num_tokens: int = 1) -> None: + """Call method definition.""" + if self._tpm_limiter is not None: + await self._tpm_limiter.acquire(num_tokens) + if self._rpm_limiter is not None: + await self._rpm_limiter.acquire() diff --git a/func-app/graphrag/llm/mock/__init__.py b/func-app/graphrag/llm/mock/__init__.py new file mode 100644 index 0000000000..cd1f000dd1 --- /dev/null +++ b/func-app/graphrag/llm/mock/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Mock LLM Implementations.""" + +from .mock_chat_llm import MockChatLLM +from .mock_completion_llm import MockCompletionLLM + +__all__ = [ + "MockChatLLM", + "MockCompletionLLM", +] diff --git a/func-app/graphrag/llm/mock/mock_chat_llm.py b/func-app/graphrag/llm/mock/mock_chat_llm.py new file mode 100644 index 0000000000..b8a6650b31 --- /dev/null +++ b/func-app/graphrag/llm/mock/mock_chat_llm.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A mock ChatLLM that returns the given responses.""" + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + CompletionInput, + CompletionOutput, + LLMInput, + LLMOutput, +) + + +class MockChatLLM( + BaseLLM[ + CompletionInput, + CompletionOutput, + ] +): + """A mock LLM that returns the given responses.""" + + responses: list[str] + i: int = 0 + + def __init__(self, responses: list[str]): + self.i = 0 + self.responses = responses + + def _create_output( + self, + output: CompletionOutput | None, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[CompletionOutput]: + history = kwargs.get("history") or [] + return LLMOutput[CompletionOutput]( + output=output, history=[*history, {"content": output}] + ) + + async def _execute_llm( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> CompletionOutput: + if self.i >= len(self.responses): + msg = f"No more responses, requested {self.i} but only have {len(self.responses)}" + raise ValueError(msg) + response = self.responses[self.i] + self.i += 1 + return response diff --git a/func-app/graphrag/llm/mock/mock_completion_llm.py b/func-app/graphrag/llm/mock/mock_completion_llm.py new file mode 100644 index 0000000000..8cb8e95083 --- /dev/null +++ b/func-app/graphrag/llm/mock/mock_completion_llm.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Static Response method definition.""" + +import logging + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + CompletionInput, + CompletionOutput, + LLMInput, +) + +log = logging.getLogger(__name__) + + +class MockCompletionLLM( + BaseLLM[ + CompletionInput, + CompletionOutput, + ] +): + """Mock Completion LLM for testing purposes.""" + + def __init__(self, responses: list[str]): + self.responses = responses + self._on_error = None + + async def _execute_llm( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> CompletionOutput: + return self.responses[0] diff --git a/func-app/graphrag/llm/openai/__init__.py b/func-app/graphrag/llm/openai/__init__.py new file mode 100644 index 0000000000..9478e146d2 --- /dev/null +++ b/func-app/graphrag/llm/openai/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""OpenAI LLM implementations.""" + +from .create_openai_client import create_openai_client +from .factories import ( + create_openai_chat_llm, + create_openai_completion_llm, + create_openai_embedding_llm, +) +from .openai_chat_llm import OpenAIChatLLM +from .openai_completion_llm import OpenAICompletionLLM +from .openai_configuration import OpenAIConfiguration +from .openai_embeddings_llm import OpenAIEmbeddingsLLM +from .types import OpenAIClientTypes + +__all__ = [ + "OpenAIChatLLM", + "OpenAIClientTypes", + "OpenAICompletionLLM", + "OpenAIConfiguration", + "OpenAIEmbeddingsLLM", + "create_openai_chat_llm", + "create_openai_client", + "create_openai_completion_llm", + "create_openai_embedding_llm", +] diff --git a/func-app/graphrag/llm/openai/_prompts.py b/func-app/graphrag/llm/openai/_prompts.py new file mode 100644 index 0000000000..37d9f0fc70 --- /dev/null +++ b/func-app/graphrag/llm/openai/_prompts.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024 Microsoft Corporation. 
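The mocks above make the wrapper stack testable without network access. A sketch of driving MockChatLLM through the BaseLLM call path:

import asyncio

from graphrag.llm.mock import MockChatLLM


async def main() -> None:
    llm = MockChatLLM(["first canned reply", "second canned reply"])
    first = await llm("question one")
    second = await llm("question two")
    assert first.output == "first canned reply"
    assert second.output == "second canned reply"
    # A third call would fail: the mock has exhausted its scripted responses.


asyncio.run(main())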
+# Licensed under the MIT License + +"""Utility prompts for low-level LLM invocations.""" + +JSON_CHECK_PROMPT = """ +You are going to be given a malformed JSON string that threw an error during json.loads. +It probably contains unnecessary escape sequences, or it is missing a comma or colon somewhere. +Your task is to fix this string and return a well-formed JSON string containing a single object. +Eliminate any unnecessary escape sequences. +Only return valid JSON, parseable with json.loads, without commentary. + +# Examples +----------- +Text: {{ \\"title\\": \\"abc\\", \\"summary\\": \\"def\\" }} +Output: {{"title": "abc", "summary": "def"}} +----------- +Text: {{"title": "abc", "summary": "def" +Output: {{"title": "abc", "summary": "def"}} +----------- +Text: {{"title': "abc", 'summary": "def" +Output: {{"title": "abc", "summary": "def"}} +----------- +Text: "{{"title": "abc", "summary": "def"}}" +Output: {{"title": "abc", "summary": "def"}} +----------- +Text: [{{"title": "abc", "summary": "def"}}] +Output: [{{"title": "abc", "summary": "def"}}] +----------- +Text: [{{"title": "abc", "summary": "def"}}, {{ \\"title\\": \\"abc\\", \\"summary\\": \\"def\\" }}] +Output: [{{"title": "abc", "summary": "def"}}, {{"title": "abc", "summary": "def"}}] +----------- +Text: ```json\n[{{"title": "abc", "summary": "def"}}, {{ \\"title\\": \\"abc\\", \\"summary\\": \\"def\\" }}]``` +Output: [{{"title": "abc", "summary": "def"}}, {{"title": "abc", "summary": "def"}}] + + +# Real Data +Text: {input_text} +Output:""" diff --git a/func-app/graphrag/llm/openai/create_openai_client.py b/func-app/graphrag/llm/openai/create_openai_client.py new file mode 100644 index 0000000000..cd149323c6 --- /dev/null +++ b/func-app/graphrag/llm/openai/create_openai_client.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Create OpenAI client instance.""" + +import logging +from functools import cache + +from azure.identity import ManagedIdentityCredential, get_bearer_token_provider +from openai import AsyncAzureOpenAI, AsyncOpenAI + +from .openai_configuration import OpenAIConfiguration +from .types import OpenAIClientTypes + +log = logging.getLogger(__name__) + +API_BASE_REQUIRED_FOR_AZURE = "api_base is required for Azure OpenAI client" + + +@cache +def create_openai_client( + configuration: OpenAIConfiguration, azure: bool +) -> OpenAIClientTypes: + """Create a new OpenAI client instance.""" + if azure: + api_base = configuration.api_base + if api_base is None: + raise ValueError(API_BASE_REQUIRED_FOR_AZURE) + + log.info( + "Creating Azure OpenAI client api_base=%s, deployment_name=%s", + api_base, + configuration.deployment_name, + ) + if configuration.cognitive_services_endpoint is None: + cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default" + else: + cognitive_services_endpoint = configuration.cognitive_services_endpoint + + return AsyncAzureOpenAI( + api_key=configuration.api_key if configuration.api_key else None, + azure_ad_token_provider=get_bearer_token_provider( + ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint + ) + if not configuration.api_key + else None, + organization=configuration.organization, + # Azure-Specifics + api_version=configuration.api_version, + azure_endpoint=api_base, + azure_deployment=configuration.deployment_name, + # Timeout/Retry Configuration - Use Tenacity for Retries, so disable them here + timeout=configuration.request_timeout or 180.0, + max_retries=0, + ) + + log.info("Creating OpenAI client base_url=%s", configuration.api_base) + return AsyncOpenAI( + api_key=configuration.api_key, + base_url=configuration.api_base, + organization=configuration.organization, + # Timeout/Retry Configuration - Use Tenacity for Retries, so disable them here + timeout=configuration.request_timeout or 180.0, + max_retries=0, + ) diff --git a/func-app/graphrag/llm/openai/factories.py b/func-app/graphrag/llm/openai/factories.py new file mode 100644 index 0000000000..e595e2e55b --- /dev/null +++ b/func-app/graphrag/llm/openai/factories.py @@ -0,0 +1,140 @@ +# Copyright (c) 2024 Microsoft Corporation. 
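The factory above picks a transport from configuration: azure=True requires api_base and falls back to an AAD token provider when api_key is absent. A sketch of both paths; the endpoint, key, and deployment values are placeholders, and the managed-identity path only authenticates when run on Azure with the configured identity.

from graphrag.llm.openai import OpenAIConfiguration, create_openai_client

azure_config = OpenAIConfiguration({
    "api_key": "",  # empty -> the managed-identity token provider path
    "model": "gpt-4",
    "api_base": "https://example-endpoint.openai.azure.com/",  # trailing slash is stripped
    "api_version": "2024-02-15-preview",
    "deployment_name": "gpt-4",
})
azure_client = create_openai_client(azure_config, azure=True)

openai_config = OpenAIConfiguration({
    "api_key": "sk-placeholder",
    "model": "gpt-4",
})
openai_client = create_openai_client(openai_config, azure=False)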
+# Licensed under the MIT License + +"""Factory functions for creating OpenAI LLMs.""" + +import asyncio + +from graphrag.llm.base import CachingLLM, RateLimitingLLM +from graphrag.llm.limiting import LLMLimiter +from graphrag.llm.types import ( + LLM, + CompletionLLM, + EmbeddingLLM, + ErrorHandlerFn, + LLMCache, + LLMInvocationFn, + OnCacheActionFn, +) + +from .json_parsing_llm import JsonParsingLLM +from .openai_chat_llm import OpenAIChatLLM +from .openai_completion_llm import OpenAICompletionLLM +from .openai_configuration import OpenAIConfiguration +from .openai_embeddings_llm import OpenAIEmbeddingsLLM +from .openai_history_tracking_llm import OpenAIHistoryTrackingLLM +from .openai_token_replacing_llm import OpenAITokenReplacingLLM +from .types import OpenAIClientTypes +from .utils import ( + RATE_LIMIT_ERRORS, + RETRYABLE_ERRORS, + get_completion_cache_args, + get_sleep_time_from_error, + get_token_counter, +) + + +def create_openai_chat_llm( + client: OpenAIClientTypes, + config: OpenAIConfiguration, + cache: LLMCache | None = None, + limiter: LLMLimiter | None = None, + semaphore: asyncio.Semaphore | None = None, + on_invoke: LLMInvocationFn | None = None, + on_error: ErrorHandlerFn | None = None, + on_cache_hit: OnCacheActionFn | None = None, + on_cache_miss: OnCacheActionFn | None = None, +) -> CompletionLLM: + """Create an OpenAI chat LLM.""" + operation = "chat" + result = OpenAIChatLLM(client, config) + result.on_error(on_error) + if limiter is not None or semaphore is not None: + result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) + if cache is not None: + result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) + result = OpenAIHistoryTrackingLLM(result) + result = OpenAITokenReplacingLLM(result) + return JsonParsingLLM(result) + + +def create_openai_completion_llm( + client: OpenAIClientTypes, + config: OpenAIConfiguration, + cache: LLMCache | None = None, + limiter: LLMLimiter | None = None, + semaphore: asyncio.Semaphore | None = None, + on_invoke: LLMInvocationFn | None = None, + on_error: ErrorHandlerFn | None = None, + on_cache_hit: OnCacheActionFn | None = None, + on_cache_miss: OnCacheActionFn | None = None, +) -> CompletionLLM: + """Create an OpenAI completion LLM.""" + operation = "completion" + result = OpenAICompletionLLM(client, config) + result.on_error(on_error) + if limiter is not None or semaphore is not None: + result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) + if cache is not None: + result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) + return OpenAITokenReplacingLLM(result) + + +def create_openai_embedding_llm( + client: OpenAIClientTypes, + config: OpenAIConfiguration, + cache: LLMCache | None = None, + limiter: LLMLimiter | None = None, + semaphore: asyncio.Semaphore | None = None, + on_invoke: LLMInvocationFn | None = None, + on_error: ErrorHandlerFn | None = None, + on_cache_hit: OnCacheActionFn | None = None, + on_cache_miss: OnCacheActionFn | None = None, +) -> EmbeddingLLM: + """Create an OpenAI embeddings LLM.""" + operation = "embedding" + result = OpenAIEmbeddingsLLM(client, config) + result.on_error(on_error) + if limiter is not None or semaphore is not None: + result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) + if cache is not None: + result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) + return result + + +def _rate_limited( + delegate: LLM, + config: OpenAIConfiguration, + 
operation: str, + limiter: LLMLimiter | None, + semaphore: asyncio.Semaphore | None, + on_invoke: LLMInvocationFn | None, +): + result = RateLimitingLLM( + delegate, + config, + operation, + RETRYABLE_ERRORS, + RATE_LIMIT_ERRORS, + limiter, + semaphore, + get_token_counter(config), + get_sleep_time_from_error, + ) + result.on_invoke(on_invoke) + return result + + +def _cached( + delegate: LLM, + config: OpenAIConfiguration, + operation: str, + cache: LLMCache, + on_cache_hit: OnCacheActionFn | None, + on_cache_miss: OnCacheActionFn | None, +): + cache_args = get_completion_cache_args(config) + result = CachingLLM(delegate, cache_args, operation, cache) + result.on_cache_hit(on_cache_hit) + result.on_cache_miss(on_cache_miss) + return result diff --git a/func-app/graphrag/llm/openai/json_parsing_llm.py b/func-app/graphrag/llm/openai/json_parsing_llm.py new file mode 100644 index 0000000000..009c1da42e --- /dev/null +++ b/func-app/graphrag/llm/openai/json_parsing_llm.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""An LLM that unpacks cached JSON responses.""" + +from typing_extensions import Unpack + +from graphrag.llm.types import ( + LLM, + CompletionInput, + CompletionLLM, + CompletionOutput, + LLMInput, + LLMOutput, +) + +from .utils import try_parse_json_object + + +class JsonParsingLLM(LLM[CompletionInput, CompletionOutput]): + """An OpenAI History-Tracking LLM.""" + + _delegate: CompletionLLM + + def __init__(self, delegate: CompletionLLM): + self._delegate = delegate + + async def __call__( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[CompletionOutput]: + """Call the LLM with the input and kwargs.""" + result = await self._delegate(input, **kwargs) + if kwargs.get("json") and result.json is None and result.output is not None: + _, parsed_json = try_parse_json_object(result.output) + result.json = parsed_json + return result diff --git a/func-app/graphrag/llm/openai/openai_chat_llm.py b/func-app/graphrag/llm/openai/openai_chat_llm.py new file mode 100644 index 0000000000..d08f8af80c --- /dev/null +++ b/func-app/graphrag/llm/openai/openai_chat_llm.py @@ -0,0 +1,148 @@ +# Copyright (c) 2024 Microsoft Corporation. 
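Putting the factories together: each one layers the decorators in a fixed order, with rate limiting closest to the client, caching around that, and history tracking, token replacement, and JSON parsing on the outside of the chat stack. A sketch of end-to-end assembly; the key and limits below are placeholders.

import asyncio

from graphrag.llm.limiting import create_tpm_rpm_limiters
from graphrag.llm.openai import (
    OpenAIConfiguration,
    create_openai_chat_llm,
    create_openai_client,
)

config = OpenAIConfiguration({
    "api_key": "sk-placeholder",
    "model": "gpt-4",
    "tokens_per_minute": 50_000,
    "requests_per_minute": 1_000,
})
client = create_openai_client(config, azure=False)
chat_llm = create_openai_chat_llm(
    client,
    config,
    limiter=create_tpm_rpm_limiters(config),
    semaphore=asyncio.Semaphore(25),
    on_error=lambda e, stack, details: print("chat failed:", e),
)
# result = await chat_llm("Summarize this passage...")  # requires a real key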
+# Licensed under the MIT License + +"""The Chat-based language model.""" + +import logging + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + CompletionInput, + CompletionOutput, + LLMInput, + LLMOutput, +) + +from ._prompts import JSON_CHECK_PROMPT +from .openai_configuration import OpenAIConfiguration +from .types import OpenAIClientTypes +from .utils import ( + get_completion_llm_args, + try_parse_json_object, +) + +log = logging.getLogger(__name__) + +_MAX_GENERATION_RETRIES = 3 +FAILED_TO_CREATE_JSON_ERROR = "Failed to generate valid JSON output" + + +class OpenAIChatLLM(BaseLLM[CompletionInput, CompletionOutput]): + """A Chat-based LLM.""" + + _client: OpenAIClientTypes + _configuration: OpenAIConfiguration + + def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration): + self.client = client + self.configuration = configuration + + async def _execute_llm( + self, input: CompletionInput, **kwargs: Unpack[LLMInput] + ) -> CompletionOutput | None: + args = get_completion_llm_args( + kwargs.get("model_parameters"), self.configuration + ) + history = kwargs.get("history") or [] + messages = [ + *history, + {"role": "user", "content": input}, + ] + completion = await self.client.chat.completions.create( + messages=messages, **args + ) + return completion.choices[0].message.content + + async def _invoke_json( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[CompletionOutput]: + """Generate JSON output.""" + name = kwargs.get("name") or "unknown" + is_response_valid = kwargs.get("is_response_valid") or (lambda _x: True) + + async def generate( + attempt: int | None = None, + ) -> LLMOutput[CompletionOutput]: + call_name = name if attempt is None else f"{name}@{attempt}" + return ( + await self._native_json(input, **{**kwargs, "name": call_name}) + if self.configuration.model_supports_json + else await self._manual_json(input, **{**kwargs, "name": call_name}) + ) + + def is_valid(x: dict | None) -> bool: + return x is not None and is_response_valid(x) + + result = await generate() + retry = 0 + while not is_valid(result.json) and retry < _MAX_GENERATION_RETRIES: + result = await generate(retry) + retry += 1 + + if is_valid(result.json): + return result + raise RuntimeError(FAILED_TO_CREATE_JSON_ERROR) + + async def _native_json( + self, input: CompletionInput, **kwargs: Unpack[LLMInput] + ) -> LLMOutput[CompletionOutput]: + """Generate JSON output using a model's native JSON-output support.""" + result = await self._invoke( + input, + **{ + **kwargs, + "model_parameters": { + **(kwargs.get("model_parameters") or {}), + "response_format": {"type": "json_object"}, + }, + }, + ) + + output, json_output = try_parse_json_object(result.output or "") + + return LLMOutput[CompletionOutput]( + output=output, + json=json_output, + history=result.history, + ) + + async def _manual_json( + self, input: CompletionInput, **kwargs: Unpack[LLMInput] + ) -> LLMOutput[CompletionOutput]: + # Otherwise, clean up the output and try to parse it as json + result = await self._invoke(input, **kwargs) + history = result.history or [] + output, json_output = try_parse_json_object(result.output or "") + if json_output: + return LLMOutput[CompletionOutput]( + output=result.output, json=json_output, history=history + ) + # if not return correct formatted json, retry + log.warning("error parsing llm json, retrying") + + # If cleaned up json is unparsable, use the LLM to reformat it (may throw) + 
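+        # This second pass feeds the unparsable text back through the model
+        # via JSON_CHECK_PROMPT (substituted through the "variables" kwarg).
+        # If the reformatted output still fails to parse, try_parse_json_object
+        # yields an empty dict, which a caller-supplied is_response_valid can
+        # reject to trigger another attempt in _invoke_json.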
result = await self._try_clean_json_with_llm(output, **kwargs) + output, json_output = try_parse_json_object(result.output or "") + + return LLMOutput[CompletionOutput]( + output=output, + json=json_output, + history=history, + ) + + async def _try_clean_json_with_llm( + self, output: str, **kwargs: Unpack[LLMInput] + ) -> LLMOutput[CompletionOutput]: + name = kwargs.get("name") or "unknown" + return await self._invoke( + JSON_CHECK_PROMPT, + **{ + **kwargs, + "variables": {"input_text": output}, + "name": f"fix_json@{name}", + }, + ) diff --git a/func-app/graphrag/llm/openai/openai_completion_llm.py b/func-app/graphrag/llm/openai/openai_completion_llm.py new file mode 100644 index 0000000000..bdbac6c131 --- /dev/null +++ b/func-app/graphrag/llm/openai/openai_completion_llm.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A text-completion based LLM.""" + +import logging + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + CompletionInput, + CompletionOutput, + LLMInput, +) + +from .openai_configuration import OpenAIConfiguration +from .types import OpenAIClientTypes +from .utils import get_completion_llm_args + +log = logging.getLogger(__name__) + + +class OpenAICompletionLLM(BaseLLM[CompletionInput, CompletionOutput]): + """A text-completion based LLM.""" + + _client: OpenAIClientTypes + _configuration: OpenAIConfiguration + + def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration): + self.client = client + self.configuration = configuration + + async def _execute_llm( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> CompletionOutput | None: + args = get_completion_llm_args( + kwargs.get("model_parameters"), self.configuration + ) + completion = self.client.completions.create(prompt=input, **args) + return completion.choices[0].text diff --git a/func-app/graphrag/llm/openai/openai_configuration.py b/func-app/graphrag/llm/openai/openai_configuration.py new file mode 100644 index 0000000000..1bcd5694d6 --- /dev/null +++ b/func-app/graphrag/llm/openai/openai_configuration.py @@ -0,0 +1,288 @@ +# Copyright (c) 2024 Microsoft Corporation. 
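One detail worth flagging in OpenAICompletionLLM above: with the async client types this module uses (AsyncOpenAI / AsyncAzureOpenAI), completions.create returns an awaitable, so _execute_llm needs an await to match the chat and embeddings implementations. A corrected sketch of the method:

    async def _execute_llm(
        self,
        input: CompletionInput,
        **kwargs: Unpack[LLMInput],
    ) -> CompletionOutput | None:
        args = get_completion_llm_args(
            kwargs.get("model_parameters"), self.configuration
        )
        # completions.create must be awaited on the async client; otherwise
        # completion is a coroutine and .choices raises AttributeError.
        completion = await self.client.completions.create(prompt=input, **args)
        return completion.choices[0].text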
+# Licensed under the MIT License + +"""OpenAI Configuration class definition.""" + +import json +from collections.abc import Hashable +from typing import Any, cast + +from graphrag.llm.types import LLMConfig + + +def _non_blank(value: str | None) -> str | None: + if value is None: + return None + stripped = value.strip() + return None if stripped == "" else value + + +class OpenAIConfiguration(Hashable, LLMConfig): + """OpenAI Configuration class definition.""" + + # Core Configuration + _api_key: str + _model: str + + _api_base: str | None + _api_version: str | None + _cognitive_services_endpoint: str | None + _deployment_name: str | None + _organization: str | None + _proxy: str | None + + # Operation Configuration + _n: int | None + _temperature: float | None + _frequency_penalty: float | None + _presence_penalty: float | None + _top_p: float | None + _max_tokens: int | None + _response_format: str | None + _logit_bias: dict[str, float] | None + _stop: list[str] | None + + # Retry Logic + _max_retries: int | None + _max_retry_wait: float | None + _request_timeout: float | None + + # The raw configuration object + _raw_config: dict + + # Feature Flags + _model_supports_json: bool | None + + # Custom Configuration + _tokens_per_minute: int | None + _requests_per_minute: int | None + _concurrent_requests: int | None + _encoding_model: str | None + _sleep_on_rate_limit_recommendation: bool | None + + def __init__( + self, + config: dict, + ): + """Init method definition.""" + + def lookup_required(key: str) -> str: + return cast(str, config.get(key)) + + def lookup_str(key: str) -> str | None: + return cast(str | None, config.get(key)) + + def lookup_int(key: str) -> int | None: + result = config.get(key) + if result is None: + return None + return int(cast(int, result)) + + def lookup_float(key: str) -> float | None: + result = config.get(key) + if result is None: + return None + return float(cast(float, result)) + + def lookup_dict(key: str) -> dict | None: + return cast(dict | None, config.get(key)) + + def lookup_list(key: str) -> list | None: + return cast(list | None, config.get(key)) + + def lookup_bool(key: str) -> bool | None: + value = config.get(key) + if isinstance(value, str): + return value.upper() == "TRUE" + if isinstance(value, int): + return value > 0 + return cast(bool | None, config.get(key)) + + self._api_key = lookup_required("api_key") + self._model = lookup_required("model") + self._deployment_name = lookup_str("deployment_name") + self._api_base = lookup_str("api_base") + self._api_version = lookup_str("api_version") + self._cognitive_services_endpoint = lookup_str("cognitive_services_endpoint") + self._organization = lookup_str("organization") + self._proxy = lookup_str("proxy") + self._n = lookup_int("n") + self._temperature = lookup_float("temperature") + self._frequency_penalty = lookup_float("frequency_penalty") + self._presence_penalty = lookup_float("presence_penalty") + self._top_p = lookup_float("top_p") + self._max_tokens = lookup_int("max_tokens") + self._response_format = lookup_str("response_format") + self._logit_bias = lookup_dict("logit_bias") + self._stop = lookup_list("stop") + self._max_retries = lookup_int("max_retries") + self._request_timeout = lookup_float("request_timeout") + self._model_supports_json = lookup_bool("model_supports_json") + self._tokens_per_minute = lookup_int("tokens_per_minute") + self._requests_per_minute = lookup_int("requests_per_minute") + self._concurrent_requests = lookup_int("concurrent_requests") + 
self._encoding_model = lookup_str("encoding_model") + self._max_retry_wait = lookup_float("max_retry_wait") + self._sleep_on_rate_limit_recommendation = lookup_bool( + "sleep_on_rate_limit_recommendation" + ) + self._raw_config = config + + @property + def api_key(self) -> str: + """API key property definition.""" + return self._api_key + + @property + def model(self) -> str: + """Model property definition.""" + return self._model + + @property + def deployment_name(self) -> str | None: + """Deployment name property definition.""" + return _non_blank(self._deployment_name) + + @property + def api_base(self) -> str | None: + """API base property definition.""" + result = _non_blank(self._api_base) + # Remove trailing slash + return result[:-1] if result and result.endswith("/") else result + + @property + def api_version(self) -> str | None: + """API version property definition.""" + return _non_blank(self._api_version) + + @property + def cognitive_services_endpoint(self) -> str | None: + """API version property definition.""" + return _non_blank(self._cognitive_services_endpoint) + + @property + def organization(self) -> str | None: + """Organization property definition.""" + return _non_blank(self._organization) + + @property + def proxy(self) -> str | None: + """Proxy property definition.""" + return _non_blank(self._proxy) + + @property + def n(self) -> int | None: + """N property definition.""" + return self._n + + @property + def temperature(self) -> float | None: + """Temperature property definition.""" + return self._temperature + + @property + def frequency_penalty(self) -> float | None: + """Frequency penalty property definition.""" + return self._frequency_penalty + + @property + def presence_penalty(self) -> float | None: + """Presence penalty property definition.""" + return self._presence_penalty + + @property + def top_p(self) -> float | None: + """Top p property definition.""" + return self._top_p + + @property + def max_tokens(self) -> int | None: + """Max tokens property definition.""" + return self._max_tokens + + @property + def response_format(self) -> str | None: + """Response format property definition.""" + return _non_blank(self._response_format) + + @property + def logit_bias(self) -> dict[str, float] | None: + """Logit bias property definition.""" + return self._logit_bias + + @property + def stop(self) -> list[str] | None: + """Stop property definition.""" + return self._stop + + @property + def max_retries(self) -> int | None: + """Max retries property definition.""" + return self._max_retries + + @property + def max_retry_wait(self) -> float | None: + """Max retry wait property definition.""" + return self._max_retry_wait + + @property + def request_timeout(self) -> float | None: + """Request timeout property definition.""" + return self._request_timeout + + @property + def model_supports_json(self) -> bool | None: + """Model supports json property definition.""" + return self._model_supports_json + + @property + def tokens_per_minute(self) -> int | None: + """Tokens per minute property definition.""" + return self._tokens_per_minute + + @property + def requests_per_minute(self) -> int | None: + """Requests per minute property definition.""" + return self._requests_per_minute + + @property + def concurrent_requests(self) -> int | None: + """Concurrent requests property definition.""" + return self._concurrent_requests + + @property + def encoding_model(self) -> str | None: + """Encoding model property definition.""" + return _non_blank(self._encoding_model) + 
+ @property + def sleep_on_rate_limit_recommendation(self) -> bool | None: + """Whether to sleep for seconds when recommended by 429 errors (azure-specific).""" + return self._sleep_on_rate_limit_recommendation + + @property + def raw_config(self) -> dict: + """Raw config method definition.""" + return self._raw_config + + def lookup(self, name: str, default_value: Any = None) -> Any: + """Lookup method definition.""" + return self._raw_config.get(name, default_value) + + def __str__(self) -> str: + """Str method definition.""" + return json.dumps(self.raw_config, indent=4) + + def __repr__(self) -> str: + """Repr method definition.""" + return f"OpenAIConfiguration({self._raw_config})" + + def __eq__(self, other: object) -> bool: + """Eq method definition.""" + if not isinstance(other, OpenAIConfiguration): + return False + return self._raw_config == other._raw_config + + def __hash__(self) -> int: + """Hash method definition.""" + return hash(tuple(sorted(self._raw_config.items()))) diff --git a/func-app/graphrag/llm/openai/openai_embeddings_llm.py b/func-app/graphrag/llm/openai/openai_embeddings_llm.py new file mode 100644 index 0000000000..558afe8437 --- /dev/null +++ b/func-app/graphrag/llm/openai/openai_embeddings_llm.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The EmbeddingsLLM class.""" + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + EmbeddingInput, + EmbeddingOutput, + LLMInput, +) + +from .openai_configuration import OpenAIConfiguration +from .types import OpenAIClientTypes + + +class OpenAIEmbeddingsLLM(BaseLLM[EmbeddingInput, EmbeddingOutput]): + """A text-embedding generator LLM.""" + + _client: OpenAIClientTypes + _configuration: OpenAIConfiguration + + def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration): + self.client = client + self.configuration = configuration + + async def _execute_llm( + self, input: EmbeddingInput, **kwargs: Unpack[LLMInput] + ) -> EmbeddingOutput | None: + args = { + "model": self.configuration.model, + **(kwargs.get("model_parameters") or {}), + } + embedding = await self.client.embeddings.create( + input=input, + **args, + ) + return [d.embedding for d in embedding.data] diff --git a/func-app/graphrag/llm/openai/openai_history_tracking_llm.py b/func-app/graphrag/llm/openai/openai_history_tracking_llm.py new file mode 100644 index 0000000000..ab903c2d2a --- /dev/null +++ b/func-app/graphrag/llm/openai/openai_history_tracking_llm.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024 Microsoft Corporation. 
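A sketch of the embeddings path above: the input type is a batch (list[str]) and the output is one vector per input. The key and model name below are placeholders.

import asyncio

from graphrag.llm.openai import (
    OpenAIConfiguration,
    create_openai_client,
    create_openai_embedding_llm,
)

config = OpenAIConfiguration({
    "api_key": "sk-placeholder",
    "model": "text-embedding-3-small",
})
client = create_openai_client(config, azure=False)
embed_llm = create_openai_embedding_llm(client, config)


async def main() -> None:
    result = await embed_llm(["first chunk", "second chunk"])
    vectors = result.output or []
    print(len(vectors), "embeddings of dim", len(vectors[0]))


# asyncio.run(main())  # requires a real API key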
+# Licensed under the MIT License + +"""The Chat-based language model.""" + +from typing_extensions import Unpack + +from graphrag.llm.types import ( + LLM, + CompletionInput, + CompletionLLM, + CompletionOutput, + LLMInput, + LLMOutput, +) + + +class OpenAIHistoryTrackingLLM(LLM[CompletionInput, CompletionOutput]): + """An OpenAI History-Tracking LLM.""" + + _delegate: CompletionLLM + + def __init__(self, delegate: CompletionLLM): + self._delegate = delegate + + async def __call__( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[CompletionOutput]: + """Call the LLM.""" + history = kwargs.get("history") or [] + output = await self._delegate(input, **kwargs) + return LLMOutput( + output=output.output, + json=output.json, + history=[ + *history, + {"role": "user", "content": input}, + {"role": "assistant", "content": output.output}, + ], + ) diff --git a/func-app/graphrag/llm/openai/openai_token_replacing_llm.py b/func-app/graphrag/llm/openai/openai_token_replacing_llm.py new file mode 100644 index 0000000000..7385b84059 --- /dev/null +++ b/func-app/graphrag/llm/openai/openai_token_replacing_llm.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Chat-based language model.""" + +from typing_extensions import Unpack + +from graphrag.llm.types import ( + LLM, + CompletionInput, + CompletionLLM, + CompletionOutput, + LLMInput, + LLMOutput, +) + +from .utils import perform_variable_replacements + + +class OpenAITokenReplacingLLM(LLM[CompletionInput, CompletionOutput]): + """An OpenAI History-Tracking LLM.""" + + _delegate: CompletionLLM + + def __init__(self, delegate: CompletionLLM): + self._delegate = delegate + + async def __call__( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[CompletionOutput]: + """Call the LLM with the input and kwargs.""" + variables = kwargs.get("variables") + history = kwargs.get("history") or [] + input = perform_variable_replacements(input, history, variables) + return await self._delegate(input, **kwargs) diff --git a/func-app/graphrag/llm/openai/types.py b/func-app/graphrag/llm/openai/types.py new file mode 100644 index 0000000000..4aacf18c1c --- /dev/null +++ b/func-app/graphrag/llm/openai/types.py @@ -0,0 +1,11 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A base class for OpenAI-based LLMs.""" + +from openai import ( + AsyncAzureOpenAI, + AsyncOpenAI, +) + +OpenAIClientTypes = AsyncOpenAI | AsyncAzureOpenAI diff --git a/func-app/graphrag/llm/openai/utils.py b/func-app/graphrag/llm/openai/utils.py new file mode 100644 index 0000000000..5d683951b1 --- /dev/null +++ b/func-app/graphrag/llm/openai/utils.py @@ -0,0 +1,160 @@ +# Copyright (c) 2024 Microsoft Corporation. 
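The history-tracking decorator above threads a conversation: each call returns the prior history plus the new user/assistant pair, which the caller can feed back in on the next turn. A sketch using the MockChatLLM test double:

import asyncio

from graphrag.llm.mock import MockChatLLM
from graphrag.llm.openai.openai_history_tracking_llm import OpenAIHistoryTrackingLLM


async def main() -> None:
    llm = OpenAIHistoryTrackingLLM(MockChatLLM(["Paris.", "About 2.1 million."]))
    turn1 = await llm("What is the capital of France?")
    turn2 = await llm("What is its population?", history=turn1.history)
    # turn2.history now holds all four messages, alternating user/assistant.
    assert len(turn2.history or []) == 4


asyncio.run(main())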
+# Licensed under the MIT License + +"""Utility functions for the OpenAI API.""" + +import json +import logging +import re +from collections.abc import Callable +from typing import Any + +import tiktoken +from json_repair import repair_json +from openai import ( + APIConnectionError, + InternalServerError, + RateLimitError, +) + +from .openai_configuration import OpenAIConfiguration + +DEFAULT_ENCODING = "cl100k_base" + +_encoders: dict[str, tiktoken.Encoding] = {} + +RETRYABLE_ERRORS: list[type[Exception]] = [ + RateLimitError, + APIConnectionError, + InternalServerError, +] +RATE_LIMIT_ERRORS: list[type[Exception]] = [RateLimitError] + +log = logging.getLogger(__name__) + + +def get_token_counter(config: OpenAIConfiguration) -> Callable[[str], int]: + """Get a function that counts the number of tokens in a string.""" + model = config.encoding_model or "cl100k_base" + enc = _encoders.get(model) + if enc is None: + enc = tiktoken.get_encoding(model) + _encoders[model] = enc + + return lambda s: len(enc.encode(s)) + + +def perform_variable_replacements( + input: str, history: list[dict], variables: dict | None +) -> str: + """Perform variable replacements on the input string and in a chat log.""" + result = input + + def replace_all(input: str) -> str: + result = input + if variables: + for entry in variables: + result = result.replace(f"{{{entry}}}", variables[entry]) + return result + + result = replace_all(result) + for i in range(len(history)): + entry = history[i] + if entry.get("role") == "system": + history[i]["content"] = replace_all(entry.get("content") or "") + + return result + + +def get_completion_cache_args(configuration: OpenAIConfiguration) -> dict: + """Get the cache arguments for a completion LLM.""" + return { + "model": configuration.model, + "temperature": configuration.temperature, + "frequency_penalty": configuration.frequency_penalty, + "presence_penalty": configuration.presence_penalty, + "top_p": configuration.top_p, + "max_tokens": configuration.max_tokens, + "n": configuration.n, + } + + +def get_completion_llm_args( + parameters: dict | None, configuration: OpenAIConfiguration +) -> dict: + """Get the arguments for a completion LLM.""" + return { + **get_completion_cache_args(configuration), + **(parameters or {}), + } + + +def try_parse_json_object(input: str) -> tuple[str, dict]: + """JSON cleaning and formatting utilities.""" + # Sometimes, the LLM returns a json string with some extra description, this function will clean it up. + + result = None + try: + # Try parse first + result = json.loads(input) + except json.JSONDecodeError: + log.info("Warning: Error decoding faulty json, attempting repair") + + if result: + return input, result + + _pattern = r"\{(.*)\}" + _match = re.search(_pattern, input) + input = "{" + _match.group(1) + "}" if _match else input + + # Clean up json string. + input = ( + input.replace("{{", "{") + .replace("}}", "}") + .replace('"[{', "[{") + .replace('}]"', "}]") + .replace("\\", " ") + .replace("\\n", " ") + .replace("\n", " ") + .replace("\r", "") + .strip() + ) + + # Remove JSON Markdown Frame + if input.startswith("```json"): + input = input[len("```json") :] + if input.endswith("```"): + input = input[: len(input) - len("```")] + + try: + result = json.loads(input) + except json.JSONDecodeError: + # Fixup potentially malformed json string using json_repair. + input = str(repair_json(json_str=input, return_objects=False)) + + # Generate JSON-string output using best-attempt prompting & parsing techniques. 
+ try: + result = json.loads(input) + except json.JSONDecodeError: + log.exception("error loading json, json=%s", input) + return input, {} + else: + if not isinstance(result, dict): + log.exception("not expected dict type. type=%s:", type(result)) + return input, {} + return input, result + else: + return input, result + + +def get_sleep_time_from_error(e: Any) -> float: + """Extract the sleep time value from a RateLimitError. This is usually only available in Azure.""" + sleep_time = 0.0 + if isinstance(e, RateLimitError) and _please_retry_after in str(e): + # could be second or seconds + sleep_time = int(str(e).split(_please_retry_after)[1].split(" second")[0]) + + return sleep_time + + +_please_retry_after = "Please retry after " diff --git a/func-app/graphrag/llm/types/__init__.py b/func-app/graphrag/llm/types/__init__.py new file mode 100644 index 0000000000..c8277661d5 --- /dev/null +++ b/func-app/graphrag/llm/types/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Typings.""" + +from .llm import LLM +from .llm_cache import LLMCache +from .llm_callbacks import ( + ErrorHandlerFn, + IsResponseValidFn, + LLMInvocationFn, + OnCacheActionFn, +) +from .llm_config import LLMConfig +from .llm_invocation_result import LLMInvocationResult +from .llm_io import ( + LLMInput, + LLMOutput, +) +from .llm_types import ( + CompletionInput, + CompletionLLM, + CompletionOutput, + EmbeddingInput, + EmbeddingLLM, + EmbeddingOutput, +) + +__all__ = [ + "LLM", + "CompletionInput", + "CompletionLLM", + "CompletionOutput", + "EmbeddingInput", + "EmbeddingLLM", + "EmbeddingOutput", + "ErrorHandlerFn", + "IsResponseValidFn", + "LLMCache", + "LLMConfig", + "LLMInput", + "LLMInvocationFn", + "LLMInvocationResult", + "LLMOutput", + "OnCacheActionFn", +] diff --git a/func-app/graphrag/llm/types/llm.py b/func-app/graphrag/llm/types/llm.py new file mode 100644 index 0000000000..fd8407e50e --- /dev/null +++ b/func-app/graphrag/llm/types/llm.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Types.""" + +from typing import Generic, Protocol, TypeVar + +from typing_extensions import Unpack + +from .llm_io import ( + LLMInput, + LLMOutput, +) + +TIn = TypeVar("TIn", contravariant=True) +TOut = TypeVar("TOut") + + +class LLM(Protocol, Generic[TIn, TOut]): + """LLM Protocol definition.""" + + async def __call__( + self, + input: TIn, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[TOut]: + """Invoke the LLM, treating the LLM as a function.""" + ... diff --git a/func-app/graphrag/llm/types/llm_cache.py b/func-app/graphrag/llm/types/llm_cache.py new file mode 100644 index 0000000000..952b8d346d --- /dev/null +++ b/func-app/graphrag/llm/types/llm_cache.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Typing definitions for the OpenAI DataShaper package.""" + +from typing import Any, Protocol + + +class LLMCache(Protocol): + """LLM Cache interface.""" + + async def has(self, key: str) -> bool: + """Check if the cache has a value.""" + ... + + async def get(self, key: str) -> Any | None: + """Retrieve a value from the cache.""" + ... + + async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None: + """Write a value into the cache.""" + ... 
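try_parse_json_object is deliberately forgiving. A few illustrative inputs and their outcomes, based on the cleanup rules above:

from graphrag.llm.openai.utils import try_parse_json_object

# Clean JSON parses directly.
text, obj = try_parse_json_object('{"title": "abc"}')
assert obj == {"title": "abc"}

# A markdown-fenced payload has its wrapper stripped before parsing.
text, obj = try_parse_json_object('```json{"title": "abc"}```')
assert obj == {"title": "abc"}

# Unsalvageable input yields an empty dict rather than raising.
text, obj = try_parse_json_object("no json here")
assert obj == {}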
diff --git a/func-app/graphrag/llm/types/llm_callbacks.py b/func-app/graphrag/llm/types/llm_callbacks.py
new file mode 100644
index 0000000000..dc06dbff06
--- /dev/null
+++ b/func-app/graphrag/llm/types/llm_callbacks.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Typing definitions for LLM callbacks."""
+
+from collections.abc import Callable
+
+from .llm_invocation_result import LLMInvocationResult
+
+ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
+"""Error handler function type definition."""
+
+LLMInvocationFn = Callable[[LLMInvocationResult], None]
+"""Handler for LLM invocation results"""
+
+OnCacheActionFn = Callable[[str, str | None], None]
+"""Handler for cache hits"""
+
+IsResponseValidFn = Callable[[dict], bool]
+"""A function that checks if an LLM response is valid."""
diff --git a/func-app/graphrag/llm/types/llm_config.py b/func-app/graphrag/llm/types/llm_config.py
new file mode 100644
index 0000000000..cd7ec255b2
--- /dev/null
+++ b/func-app/graphrag/llm/types/llm_config.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""LLM Configuration Protocol definition."""
+
+from typing import Protocol
+
+
+class LLMConfig(Protocol):
+    """LLM Configuration Protocol definition."""
+
+    @property
+    def max_retries(self) -> int | None:
+        """Get the maximum number of retries."""
+        ...
+
+    @property
+    def max_retry_wait(self) -> float | None:
+        """Get the maximum retry wait time."""
+        ...
+
+    @property
+    def sleep_on_rate_limit_recommendation(self) -> bool | None:
+        """Get whether to sleep on rate limit recommendation."""
+        ...
+
+    @property
+    def tokens_per_minute(self) -> int | None:
+        """Get the number of tokens per minute."""
+        ...
+
+    @property
+    def requests_per_minute(self) -> int | None:
+        """Get the number of requests per minute."""
+        ...
diff --git a/func-app/graphrag/llm/types/llm_invocation_result.py b/func-app/graphrag/llm/types/llm_invocation_result.py
new file mode 100644
index 0000000000..1769aeb96d
--- /dev/null
+++ b/func-app/graphrag/llm/types/llm_invocation_result.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Typing definitions for LLM invocation results."""
+
+from dataclasses import dataclass
+from typing import Generic, TypeVar
+
+T = TypeVar("T")
+
+
+@dataclass
+class LLMInvocationResult(Generic[T]):
+    """The result of an LLM invocation."""
+
+    result: T | None
+    """The result of the LLM invocation."""
+
+    name: str
+    """The operation name of the result"""
+
+    num_retries: int
+    """The number of retries the invocation took."""
+
+    total_time: float
+    """The total time of the LLM invocation."""
+
+    call_times: list[float]
+    """The network times of individual invocations."""
+
+    input_tokens: int
+    """The number of input tokens."""
+
+    output_tokens: int
+    """The number of output tokens."""
diff --git a/func-app/graphrag/llm/types/llm_io.py b/func-app/graphrag/llm/types/llm_io.py
new file mode 100644
index 0000000000..256f3c8ce8
--- /dev/null
+++ b/func-app/graphrag/llm/types/llm_io.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""LLM Types.""" + +from dataclasses import dataclass, field +from typing import Generic, TypeVar + +from typing_extensions import NotRequired, TypedDict + +from .llm_callbacks import IsResponseValidFn + + +class LLMInput(TypedDict): + """The input of an LLM invocation.""" + + name: NotRequired[str] + """The name of the LLM invocation, if available.""" + + json: NotRequired[bool] + """If true, will attempt to elicit JSON from the LLM. Parsed JSON will be returned in the `json_output` field.""" + + is_response_valid: NotRequired[IsResponseValidFn] + """A function that checks if an LLM response is valid. Only valid if `json=True`.""" + + variables: NotRequired[dict] + """The variable replacements to use in the prompt.""" + + history: NotRequired[list[dict] | None] + """The history of the LLM invocation, if available (e.g. chat mode)""" + + model_parameters: NotRequired[dict] + """Additional model parameters to use in the LLM invocation.""" + + +T = TypeVar("T") + + +@dataclass +class LLMOutput(Generic[T]): + """The output of an LLM invocation.""" + + output: T | None + """The output of the LLM invocation.""" + + json: dict | None = field(default=None) + """The JSON output from the LLM, if available.""" + + history: list[dict] | None = field(default=None) + """The history of the LLM invocation, if available (e.g. chat mode)""" diff --git a/func-app/graphrag/llm/types/llm_types.py b/func-app/graphrag/llm/types/llm_types.py new file mode 100644 index 0000000000..7ae76ef9be --- /dev/null +++ b/func-app/graphrag/llm/types/llm_types.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Types.""" + +from typing import TypeAlias + +from .llm import LLM + +EmbeddingInput: TypeAlias = list[str] +EmbeddingOutput: TypeAlias = list[list[float]] +CompletionInput: TypeAlias = str +CompletionOutput: TypeAlias = str + +EmbeddingLLM: TypeAlias = LLM[EmbeddingInput, EmbeddingOutput] +CompletionLLM: TypeAlias = LLM[CompletionInput, CompletionOutput] diff --git a/func-app/graphrag/model/__init__.py b/func-app/graphrag/model/__init__.py new file mode 100644 index 0000000000..9dbec3d1dd --- /dev/null +++ b/func-app/graphrag/model/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +""" +GraphRAG knowledge model package root. + +The GraphRAG knowledge model contains a set of classes that represent the target datamodels for our pipelines and analytics tools. +These models can be augmented and integrated into your own data infrastructure to suit your needs. +""" + +from .community import Community +from .community_report import CommunityReport +from .covariate import Covariate +from .document import Document +from .entity import Entity +from .identified import Identified +from .named import Named +from .relationship import Relationship +from .text_unit import TextUnit + +__all__ = [ + "Community", + "CommunityReport", + "Covariate", + "Document", + "Entity", + "Identified", + "Named", + "Relationship", + "TextUnit", +] diff --git a/func-app/graphrag/model/community.py b/func-app/graphrag/model/community.py new file mode 100644 index 0000000000..800a9a292a --- /dev/null +++ b/func-app/graphrag/model/community.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A package containing the 'Community' model.""" + +from dataclasses import dataclass +from typing import Any + +from .named import Named + + +@dataclass +class Community(Named): + """A protocol for a community in the system.""" + + level: str = "" + """Community level.""" + + entity_ids: list[str] | None = None + """List of entity IDs related to the community (optional).""" + + relationship_ids: list[str] | None = None + """List of relationship IDs related to the community (optional).""" + + covariate_ids: dict[str, list[str]] | None = None + """Dictionary of different types of covariates related to the community (optional), e.g. claims""" + + attributes: dict[str, Any] | None = None + """A dictionary of additional attributes associated with the community (optional). To be included in the search prompt.""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + title_key: str = "title", + short_id_key: str = "short_id", + level_key: str = "level", + entities_key: str = "entity_ids", + relationships_key: str = "relationship_ids", + covariates_key: str = "covariate_ids", + attributes_key: str = "attributes", + ) -> "Community": + """Create a new community from the dict data.""" + return Community( + id=d[id_key], + title=d[title_key], + short_id=d.get(short_id_key), + level=d[level_key], + entity_ids=d.get(entities_key), + relationship_ids=d.get(relationships_key), + covariate_ids=d.get(covariates_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/community_report.py b/func-app/graphrag/model/community_report.py new file mode 100644 index 0000000000..2666c0b5a8 --- /dev/null +++ b/func-app/graphrag/model/community_report.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'CommunityReport' model.""" + +from dataclasses import dataclass +from typing import Any + +from .named import Named + + +@dataclass +class CommunityReport(Named): + """Defines an LLM-generated summary report of a community.""" + + community_id: str + """The ID of the community this report is associated with.""" + + summary: str = "" + """Summary of the report.""" + + full_content: str = "" + """Full content of the report.""" + + rank: float | None = 1.0 + """Rank of the report, used for sorting (optional). Higher means more important""" + + summary_embedding: list[float] | None = None + """The semantic (i.e. text) embedding of the report summary (optional).""" + + full_content_embedding: list[float] | None = None + """The semantic (i.e. 
text) embedding of the full report content (optional).""" + + attributes: dict[str, Any] | None = None + """A dictionary of additional attributes associated with the report (optional).""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + title_key: str = "title", + community_id_key: str = "community_id", + short_id_key: str = "short_id", + summary_key: str = "summary", + full_content_key: str = "full_content", + rank_key: str = "rank", + summary_embedding_key: str = "summary_embedding", + full_content_embedding_key: str = "full_content_embedding", + attributes_key: str = "attributes", + ) -> "CommunityReport": + """Create a new community report from the dict data.""" + return CommunityReport( + id=d[id_key], + title=d[title_key], + community_id=d[community_id_key], + short_id=d.get(short_id_key), + summary=d[summary_key], + full_content=d[full_content_key], + rank=d[rank_key], + summary_embedding=d.get(summary_embedding_key), + full_content_embedding=d.get(full_content_embedding_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/covariate.py b/func-app/graphrag/model/covariate.py new file mode 100644 index 0000000000..b974b6b327 --- /dev/null +++ b/func-app/graphrag/model/covariate.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'Covariate' model.""" + +from dataclasses import dataclass +from typing import Any + +from .identified import Identified + + +@dataclass +class Covariate(Identified): + """ + A protocol for a covariate in the system. + + Covariates are metadata associated with a subject, e.g. entity claims. + Each subject (e.g. entity) may be associated with multiple types of covariates. + """ + + subject_id: str + """The subject id.""" + + subject_type: str = "entity" + """The subject type.""" + + covariate_type: str = "claim" + """The covariate type.""" + + text_unit_ids: list[str] | None = None + """List of text unit IDs in which the covariate info appears (optional).""" + + document_ids: list[str] | None = None + """List of document IDs in which the covariate info appears (optional).""" + + attributes: dict[str, Any] | None = None + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + subject_id_key: str = "subject_id", + subject_type_key: str = "subject_type", + covariate_type_key: str = "covariate_type", + short_id_key: str = "short_id", + text_unit_ids_key: str = "text_unit_ids", + document_ids_key: str = "document_ids", + attributes_key: str = "attributes", + ) -> "Covariate": + """Create a new covariate from the dict data.""" + return Covariate( + id=d[id_key], + short_id=d.get(short_id_key), + subject_id=d[subject_id_key], + subject_type=d.get(subject_type_key, "entity"), + covariate_type=d.get(covariate_type_key, "claim"), + text_unit_ids=d.get(text_unit_ids_key), + document_ids=d.get(document_ids_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/document.py b/func-app/graphrag/model/document.py new file mode 100644 index 0000000000..b54a39ac91 --- /dev/null +++ b/func-app/graphrag/model/document.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A package containing the 'Document' model.""" + +from dataclasses import dataclass, field +from typing import Any + +from .named import Named + + +@dataclass +class Document(Named): + """A protocol for a document in the system.""" + + type: str = "text" + """Type of the document.""" + + text_unit_ids: list[str] = field(default_factory=list) + """list of text units in the document.""" + + raw_content: str = "" + """The raw text content of the document.""" + + summary: str | None = None + """Summary of the document (optional).""" + + summary_embedding: list[float] | None = None + """The semantic embedding for the document summary (optional).""" + + raw_content_embedding: list[float] | None = None + """The semantic embedding for the document raw content (optional).""" + + attributes: dict[str, Any] | None = None + """A dictionary of structured attributes such as author, etc (optional).""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + short_id_key: str = "short_id", + title_key: str = "title", + type_key: str = "type", + raw_content_key: str = "raw_content", + summary_key: str = "summary", + summary_embedding_key: str = "summary_embedding", + raw_content_embedding_key: str = "raw_content_embedding", + text_units_key: str = "text_units", + attributes_key: str = "attributes", + ) -> "Document": + """Create a new document from the dict data.""" + return Document( + id=d[id_key], + short_id=d.get(short_id_key), + title=d[title_key], + type=d.get(type_key, "text"), + raw_content=d[raw_content_key], + summary=d.get(summary_key), + summary_embedding=d.get(summary_embedding_key), + raw_content_embedding=d.get(raw_content_embedding_key), + text_unit_ids=d.get(text_units_key, []), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/entity.py b/func-app/graphrag/model/entity.py new file mode 100644 index 0000000000..37c26342aa --- /dev/null +++ b/func-app/graphrag/model/entity.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'Entity' model.""" + +from dataclasses import dataclass +from typing import Any + +from .named import Named + + +@dataclass +class Entity(Named): + """A protocol for an entity in the system.""" + + type: str | None = None + """Type of the entity (can be any string, optional).""" + + description: str | None = None + """Description of the entity (optional).""" + + description_embedding: list[float] | None = None + """The semantic (i.e. text) embedding of the entity (optional).""" + + name_embedding: list[float] | None = None + """The semantic (i.e. text) embedding of the entity (optional).""" + + graph_embedding: list[float] | None = None + """The graph embedding of the entity, likely from node2vec (optional).""" + + community_ids: list[str] | None = None + """The community IDs of the entity (optional).""" + + text_unit_ids: list[str] | None = None + """List of text unit IDs in which the entity appears (optional).""" + + document_ids: list[str] | None = None + """List of document IDs in which the entity appears (optional).""" + + rank: int | None = 1 + """Rank of the entity, used for sorting (optional). Higher rank indicates more important entity. This can be based on centrality or other metrics.""" + + attributes: dict[str, Any] | None = None + """Additional attributes associated with the entity (optional), e.g. start time, end time, etc. 
To be included in the search prompt.""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + short_id_key: str = "short_id", + title_key: str = "title", + type_key: str = "type", + description_key: str = "description", + description_embedding_key: str = "description_embedding", + name_embedding_key: str = "name_embedding", + graph_embedding_key: str = "graph_embedding", + community_key: str = "community", + text_unit_ids_key: str = "text_unit_ids", + document_ids_key: str = "document_ids", + rank_key: str = "degree", + attributes_key: str = "attributes", + ) -> "Entity": + """Create a new entity from the dict data.""" + return Entity( + id=d[id_key], + title=d[title_key], + short_id=d.get(short_id_key), + type=d.get(type_key), + description=d.get(description_key), + name_embedding=d.get(name_embedding_key), + description_embedding=d.get(description_embedding_key), + graph_embedding=d.get(graph_embedding_key), + community_ids=d.get(community_key), + rank=d.get(rank_key, 1), + text_unit_ids=d.get(text_unit_ids_key), + document_ids=d.get(document_ids_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/identified.py b/func-app/graphrag/model/identified.py new file mode 100644 index 0000000000..ca2c939526 --- /dev/null +++ b/func-app/graphrag/model/identified.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'Identified' protocol.""" + +from dataclasses import dataclass + + +@dataclass +class Identified: + """A protocol for an item with an ID.""" + + id: str + """The ID of the item.""" + + short_id: str | None + """Human readable ID used to refer to this community in prompts or texts displayed to users, such as in a report text (optional).""" diff --git a/func-app/graphrag/model/named.py b/func-app/graphrag/model/named.py new file mode 100644 index 0000000000..5352c77c96 --- /dev/null +++ b/func-app/graphrag/model/named.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'Named' protocol.""" + +from dataclasses import dataclass + +from .identified import Identified + + +@dataclass +class Named(Identified): + """A protocol for an item with a name/title.""" + + title: str + """The name/title of the item.""" diff --git a/func-app/graphrag/model/relationship.py b/func-app/graphrag/model/relationship.py new file mode 100644 index 0000000000..fadd0aaa6f --- /dev/null +++ b/func-app/graphrag/model/relationship.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'Relationship' model.""" + +from dataclasses import dataclass +from typing import Any + +from .identified import Identified + + +@dataclass +class Relationship(Identified): + """A relationship between two entities. 
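Source and target refer to entities by name (title) rather than by id.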
This is a generic relationship, and can be used to represent any type of relationship between any two entities.""" + + source: str + """The source entity name.""" + + target: str + """The target entity name.""" + + weight: float | None = 1.0 + """The edge weight.""" + + description: str | None = None + """A description of the relationship (optional).""" + + description_embedding: list[float] | None = None + """The semantic embedding for the relationship description (optional).""" + + text_unit_ids: list[str] | None = None + """List of text unit IDs in which the relationship appears (optional).""" + + document_ids: list[str] | None = None + """List of document IDs in which the relationship appears (optional).""" + + attributes: dict[str, Any] | None = None + """Additional attributes associated with the relationship (optional). To be included in the search prompt""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + short_id_key: str = "short_id", + source_key: str = "source", + target_key: str = "target", + description_key: str = "description", + weight_key: str = "weight", + text_unit_ids_key: str = "text_unit_ids", + document_ids_key: str = "document_ids", + attributes_key: str = "attributes", + ) -> "Relationship": + """Create a new relationship from the dict data.""" + return Relationship( + id=d[id_key], + short_id=d.get(short_id_key), + source=d[source_key], + target=d[target_key], + description=d.get(description_key), + weight=d.get(weight_key, 1.0), + text_unit_ids=d.get(text_unit_ids_key), + document_ids=d.get(document_ids_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/text_unit.py b/func-app/graphrag/model/text_unit.py new file mode 100644 index 0000000000..cff4ac01c1 --- /dev/null +++ b/func-app/graphrag/model/text_unit.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'TextUnit' model.""" + +from dataclasses import dataclass +from typing import Any + +from .identified import Identified + + +@dataclass +class TextUnit(Identified): + """A protocol for a TextUnit item in a Document database.""" + + text: str + """The text of the unit.""" + + text_embedding: list[float] | None = None + """The text embedding for the text unit (optional).""" + + entity_ids: list[str] | None = None + """List of entity IDs related to the text unit (optional).""" + + relationship_ids: list[str] | None = None + """List of relationship IDs related to the text unit (optional).""" + + covariate_ids: dict[str, list[str]] | None = None + "Dictionary of different types of covariates related to the text unit (optional)." 
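+    # Example shape, with hypothetical ids: {"claim": ["claim-1", "claim-2"]},
+    # keyed by covariate type.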
+ + n_tokens: int | None = None + """The number of tokens in the text (optional).""" + + document_ids: list[str] | None = None + """List of document IDs in which the text unit appears (optional).""" + + attributes: dict[str, Any] | None = None + """A dictionary of additional attributes associated with the text unit (optional).""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + short_id_key: str = "short_id", + text_key: str = "text", + text_embedding_key: str = "text_embedding", + entities_key: str = "entity_ids", + relationships_key: str = "relationship_ids", + covariates_key: str = "covariate_ids", + n_tokens_key: str = "n_tokens", + document_ids_key: str = "document_ids", + attributes_key: str = "attributes", + ) -> "TextUnit": + """Create a new text unit from the dict data.""" + return TextUnit( + id=d[id_key], + short_id=d.get(short_id_key), + text=d[text_key], + text_embedding=d.get(text_embedding_key), + entity_ids=d.get(entities_key), + relationship_ids=d.get(relationships_key), + covariate_ids=d.get(covariates_key), + n_tokens=d.get(n_tokens_key), + document_ids=d.get(document_ids_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/types.py b/func-app/graphrag/model/types.py new file mode 100644 index 0000000000..6156e39969 --- /dev/null +++ b/func-app/graphrag/model/types.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Common types for the GraphRAG knowledge model.""" + +from collections.abc import Callable + +TextEmbedder = Callable[[str], list[float]] diff --git a/func-app/graphrag/prompt_tune/__init__.py b/func-app/graphrag/prompt_tune/__init__.py new file mode 100644 index 0000000000..2384b5793c --- /dev/null +++ b/func-app/graphrag/prompt_tune/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Command line interface for the fine_tune module.""" diff --git a/func-app/graphrag/prompt_tune/__main__.py b/func-app/graphrag/prompt_tune/__main__.py new file mode 100644 index 0000000000..e752b05a8f --- /dev/null +++ b/func-app/graphrag/prompt_tune/__main__.py @@ -0,0 +1,148 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Prompt auto templating package root.""" + +import argparse +import asyncio +from enum import Enum + +from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT +from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE + +from .cli import prompt_tune + + +class DocSelectionType(Enum): + """The type of document selection to use.""" + + ALL = "all" + RANDOM = "random" + TOP = "top" + AUTO = "auto" + + def __str__(self): + """Return the string representation of the enum value.""" + return self.value + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--root", + help="The data project root. Including the config yml, json or .env", + required=False, + type=str, + default=".", + ) + + parser.add_argument( + "--domain", + help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. 
If left empty, the domain will be inferred from the input data.",
+        required=False,
+        default="",
+        type=str,
+    )
+
+    parser.add_argument(
+        "--method",
+        help="The method to select documents, one of: all, random, top or auto",
+        required=False,
+        type=DocSelectionType,
+        choices=list(DocSelectionType),
+        default=DocSelectionType.RANDOM,
+    )
+
+    parser.add_argument(
+        "--n_subset_max",
+        help="The number of text chunks to embed when using auto selection method",
+        required=False,
+        type=int,
+        default=300,
+    )
+
+    parser.add_argument(
+        "--k",
+        help="The maximum number of documents to select from each centroid when using auto selection method",
+        required=False,
+        type=int,
+        default=15,
+    )
+
+    parser.add_argument(
+        "--limit",
+        help="The limit of files to load when doing random or top selection",
+        type=int,
+        required=False,
+        default=15,
+    )
+
+    parser.add_argument(
+        "--max-tokens",
+        help="Max token count for prompt generation",
+        type=int,
+        required=False,
+        default=MAX_TOKEN_COUNT,
+    )
+
+    parser.add_argument(
+        "--min-examples-required",
+        help="The minimum number of examples required in entity extraction prompt",
+        type=int,
+        required=False,
+        default=2,
+    )
+
+    parser.add_argument(
+        "--chunk-size",
+        help="The chunk token size to use for input text units",
+        type=int,
+        required=False,
+        default=MIN_CHUNK_SIZE,
+    )
+
+    parser.add_argument(
+        "--language",
+        help="Primary language used for inputs and outputs on GraphRAG",
+        type=str,
+        required=False,
+        default="",
+    )
+
+    parser.add_argument(
+        "--no-entity-types",
+        help="Use untyped entity extraction generation",
+        action="store_true",
+        required=False,
+        default=False,
+    )
+
+    parser.add_argument(
+        "--output",
+        help="Folder to save the generated prompts to",
+        type=str,
+        required=False,
+        default="prompts",
+    )
+
+    args = parser.parse_args()
+
+    loop = asyncio.get_event_loop()
+
+    loop.run_until_complete(
+        prompt_tune(
+            args.root,
+            args.domain,
+            str(args.method),
+            args.limit,
+            args.max_tokens,
+            args.chunk_size,
+            args.language,
+            args.no_entity_types,
+            args.output,
+            args.n_subset_max,
+            args.k,
+            args.min_examples_required,
+        )
+    )
diff --git a/func-app/graphrag/prompt_tune/cli.py b/func-app/graphrag/prompt_tune/cli.py
new file mode 100644
index 0000000000..5979a4a6ee
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/cli.py
@@ -0,0 +1,272 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Command line interface for the fine_tune module.""" + +from pathlib import Path + +from datashaper import NoopVerbCallbacks + +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.llm import load_llm +from graphrag.index.progress import PrintProgressReporter +from graphrag.index.progress.types import ProgressReporter +from graphrag.llm.types.llm_types import CompletionLLM +from graphrag.prompt_tune.generator import ( + MAX_TOKEN_COUNT, + create_community_summarization_prompt, + create_entity_extraction_prompt, + create_entity_summarization_prompt, + detect_language, + generate_community_report_rating, + generate_community_reporter_role, + generate_domain, + generate_entity_relationship_examples, + generate_entity_types, + generate_persona, +) +from graphrag.prompt_tune.loader import ( + MIN_CHUNK_SIZE, + load_docs_in_chunks, + read_config_parameters, +) + + +async def prompt_tune( + root: str, + domain: str, + select: str = "random", + limit: int = 15, + max_tokens: int = MAX_TOKEN_COUNT, + chunk_size: int = MIN_CHUNK_SIZE, + language: str | None = None, + skip_entity_types: bool = False, + output: str = "prompts", + n_subset_max: int = 300, + k: int = 15, + min_examples_required: int = 2, +): + """Prompt tune the model. + + Parameters + ---------- + - root: The root directory. + - domain: The domain to map the input documents to. + - select: The chunk selection method. + - limit: The limit of chunks to load. + - max_tokens: The maximum number of tokens to use on entity extraction prompts. + - chunk_size: The chunk token size to use. + - skip_entity_types: Skip generating entity types. + - output: The output folder to store the prompts. + - n_subset_max: The number of text chunks to embed when using auto selection method. + - k: The number of documents to select when using auto selection method. + """ + reporter = PrintProgressReporter("") + config = read_config_parameters(root, reporter) + + await prompt_tune_with_config( + root, + config, + domain, + select, + limit, + max_tokens, + chunk_size, + language, + skip_entity_types, + output, + reporter, + n_subset_max, + k, + min_examples_required, + ) + + +async def prompt_tune_with_config( + root: str, + config: GraphRagConfig, + domain: str, + select: str = "random", + limit: int = 15, + max_tokens: int = MAX_TOKEN_COUNT, + chunk_size: int = MIN_CHUNK_SIZE, + language: str | None = None, + skip_entity_types: bool = False, + output: str = "prompts", + reporter: ProgressReporter | None = None, + n_subset_max: int = 300, + k: int = 15, + min_examples_required: int = 2, +): + """Prompt tune the model with a configuration. + + Parameters + ---------- + - root: The root directory. + - config: The GraphRag configuration. + - domain: The domain to map the input documents to. + - select: The chunk selection method. + - limit: The limit of chunks to load. + - max_tokens: The maximum number of tokens to use on entity extraction prompts. + - chunk_size: The chunk token size to use for input text units. + - skip_entity_types: Skip generating entity types. + - output: The output folder to store the prompts. + - reporter: The progress reporter. + - n_subset_max: The number of text chunks to embed when using auto selection method. + - k: The number of documents to select when using auto selection method. 
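+    - language: The primary language used for inputs and outputs. If None, it is detected from the input data.
+    - min_examples_required: The minimum number of examples required for entity extraction prompts.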
+ + Returns + ------- + - None + """ + if not reporter: + reporter = PrintProgressReporter("") + + output_path = Path(config.root_dir) / output + + doc_list = await load_docs_in_chunks( + root=root, + config=config, + limit=limit, + select_method=select, + reporter=reporter, + chunk_size=chunk_size, + n_subset_max=n_subset_max, + k=k, + ) + + # Create LLM from config + llm = load_llm( + "prompt_tuning", + config.llm.type, + NoopVerbCallbacks(), + None, + config.llm.model_dump(), + ) + + await generate_indexing_prompts( + llm, + config, + doc_list, + output_path, + reporter, + domain, + language, + max_tokens, + skip_entity_types, + min_examples_required, + ) + + +async def generate_indexing_prompts( + llm: CompletionLLM, + config: GraphRagConfig, + doc_list: list[str], + output_path: Path, + reporter: ProgressReporter, + domain: str | None = None, + language: str | None = None, + max_tokens: int = MAX_TOKEN_COUNT, + skip_entity_types: bool = False, + min_examples_required: int = 2, +): + """Generate indexing prompts. + + Parameters + ---------- + - llm: The LLM model to use. + - config: The GraphRag configuration. + - doc_list: The list of documents to use. + - output_path: The path to store the prompts. + - reporter: The progress reporter. + - domain: The domain to map the input documents to. + - max_tokens: The maximum number of tokens to use on entity extraction prompts + - skip_entity_types: Skip generating entity types. + - min_examples_required: The minimum number of examples required for entity extraction prompts. + """ + if not domain: + reporter.info("Generating domain...") + domain = await generate_domain(llm, doc_list) + reporter.info(f"Generated domain: {domain}") + + if not language: + reporter.info("Detecting language...") + language = await detect_language(llm, doc_list) + reporter.info(f"Detected language: {language}") + + reporter.info("Generating persona...") + persona = await generate_persona(llm, domain) + reporter.info(f"Generated persona: {persona}") + + reporter.info("Generating community report ranking description...") + community_report_ranking = await generate_community_report_rating( + llm, domain=domain, persona=persona, docs=doc_list + ) + reporter.info( + f"Generated community report ranking description: {community_report_ranking}" + ) + + entity_types = None + if not skip_entity_types: + reporter.info("Generating entity types") + entity_types = await generate_entity_types( + llm, + domain=domain, + persona=persona, + docs=doc_list, + json_mode=config.llm.model_supports_json or False, + ) + reporter.info(f"Generated entity types: {entity_types}") + + reporter.info("Generating entity relationship examples...") + examples = await generate_entity_relationship_examples( + llm, + persona=persona, + entity_types=entity_types, + docs=doc_list, + language=language, + json_mode=False, # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine + ) + reporter.info("Done generating entity relationship examples") + + reporter.info("Generating entity extraction prompt...") + create_entity_extraction_prompt( + entity_types=entity_types, + docs=doc_list, + examples=examples, + language=language, + json_mode=False, # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine + output_path=output_path, + encoding_model=config.encoding_model, + max_token_count=max_tokens, + min_examples_required=min_examples_required, + ) + reporter.info(f"Generated entity extraction prompt, stored in 
folder {output_path}")
+
+    reporter.info("Generating entity summarization prompt...")
+    create_entity_summarization_prompt(
+        persona=persona,
+        language=language,
+        output_path=output_path,
+    )
+    reporter.info(
+        f"Generated entity summarization prompt, stored in folder {output_path}"
+    )
+
+    reporter.info("Generating community reporter role...")
+    community_reporter_role = await generate_community_reporter_role(
+        llm, domain=domain, persona=persona, docs=doc_list
+    )
+    reporter.info(f"Generated community reporter role: {community_reporter_role}")
+
+    reporter.info("Generating community summarization prompt...")
+    create_community_summarization_prompt(
+        persona=persona,
+        role=community_reporter_role,
+        report_rating_description=community_report_ranking,
+        language=language,
+        output_path=output_path,
+    )
+    reporter.info(
+        f"Generated community summarization prompt, stored in folder {output_path}"
+    )
diff --git a/func-app/graphrag/prompt_tune/generator/__init__.py b/func-app/graphrag/prompt_tune/generator/__init__.py
new file mode 100644
index 0000000000..df45b46033
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/__init__.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Prompt generation module."""
+
+from .community_report_rating import generate_community_report_rating
+from .community_report_summarization import create_community_summarization_prompt
+from .community_reporter_role import generate_community_reporter_role
+from .defaults import MAX_TOKEN_COUNT
+from .domain import generate_domain
+from .entity_extraction_prompt import create_entity_extraction_prompt
+from .entity_relationship import generate_entity_relationship_examples
+from .entity_summarization_prompt import create_entity_summarization_prompt
+from .entity_types import generate_entity_types
+from .language import detect_language
+from .persona import generate_persona
+
+__all__ = [
+    "MAX_TOKEN_COUNT",
+    "create_community_summarization_prompt",
+    "create_entity_extraction_prompt",
+    "create_entity_summarization_prompt",
+    "detect_language",
+    "generate_community_report_rating",
+    "generate_community_reporter_role",
+    "generate_domain",
+    "generate_entity_relationship_examples",
+    "generate_entity_types",
+    "generate_persona",
+]
diff --git a/func-app/graphrag/prompt_tune/generator/community_report_rating.py b/func-app/graphrag/prompt_tune/generator/community_report_rating.py
new file mode 100644
index 0000000000..59f94d5698
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/community_report_rating.py
@@ -0,0 +1,35 @@
+"""Generate a rating description for community report rating."""
+
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag.llm.types.llm_types import CompletionLLM
+from graphrag.prompt_tune.prompt import (
+    GENERATE_REPORT_RATING_PROMPT,
+)
+
+
+async def generate_community_report_rating(
+    llm: CompletionLLM, domain: str, persona: str, docs: str | list[str]
+) -> str:
+    """Generate a report rating description to use for GraphRAG prompts.
+
+    Parameters
+    ----------
+    - llm (CompletionLLM): The LLM to use for generation
+    - domain (str): The domain to generate a rating for
+    - persona (str): The persona to generate a rating for
+    - docs (str | list[str]): Documents used to contextualize the rating
+
+    Returns
+    -------
+    - str: The generated rating description prompt response.
+ """ + docs_str = " ".join(docs) if isinstance(docs, list) else docs + domain_prompt = GENERATE_REPORT_RATING_PROMPT.format( + domain=domain, persona=persona, input_text=docs_str + ) + + response = await llm(domain_prompt) + + return str(response.output).strip() diff --git a/func-app/graphrag/prompt_tune/generator/community_report_summarization.py b/func-app/graphrag/prompt_tune/generator/community_report_summarization.py new file mode 100644 index 0000000000..b0c0b614d2 --- /dev/null +++ b/func-app/graphrag/prompt_tune/generator/community_report_summarization.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Module for generating prompts for community report summarization.""" + +from pathlib import Path + +from graphrag.prompt_tune.template import COMMUNITY_REPORT_SUMMARIZATION_PROMPT + +COMMUNITY_SUMMARIZATION_FILENAME = "community_report.txt" + + +def create_community_summarization_prompt( + persona: str, + role: str, + report_rating_description: str, + language: str, + output_path: Path | None = None, +) -> str: + """Create a prompt for community summarization. If output_path is provided, write the prompt to a file. + + Parameters + ---------- + - persona (str): The persona to use for the community summarization prompt + - role (str): The role to use for the community summarization prompt + - language (str): The language to use for the community summarization prompt + - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None. + + Returns + ------- + - str: The community summarization prompt + """ + prompt = COMMUNITY_REPORT_SUMMARIZATION_PROMPT.format( + persona=persona, + role=role, + report_rating_description=report_rating_description, + language=language, + ) + + if output_path: + output_path.mkdir(parents=True, exist_ok=True) + + output_path = output_path / COMMUNITY_SUMMARIZATION_FILENAME + # Write file to output path + with output_path.open("wb") as file: + file.write(prompt.encode(encoding="utf-8", errors="strict")) + + return prompt diff --git a/func-app/graphrag/prompt_tune/generator/community_reporter_role.py b/func-app/graphrag/prompt_tune/generator/community_reporter_role.py new file mode 100644 index 0000000000..9abd5ed83f --- /dev/null +++ b/func-app/graphrag/prompt_tune/generator/community_reporter_role.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Generate a community reporter role for community summarization.""" + +from graphrag.llm.types.llm_types import CompletionLLM +from graphrag.prompt_tune.prompt import ( + GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT, +) + + +async def generate_community_reporter_role( + llm: CompletionLLM, domain: str, persona: str, docs: str | list[str] +) -> str: + """Generate an LLM persona to use for GraphRAG prompts. + + Parameters + ---------- + - llm (CompletionLLM): The LLM to use for generation + - domain (str): The domain to generate a persona for + - persona (str): The persona to generate a role for + - docs (str | list[str]): The domain to generate a persona for + + Returns + ------- + - str: The generated domain prompt response. 
+ """ + docs_str = " ".join(docs) if isinstance(docs, list) else docs + domain_prompt = GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT.format( + domain=domain, persona=persona, input_text=docs_str + ) + + response = await llm(domain_prompt) + + return str(response.output) diff --git a/func-app/graphrag/prompt_tune/generator/defaults.py b/func-app/graphrag/prompt_tune/generator/defaults.py new file mode 100644 index 0000000000..5b42f81332 --- /dev/null +++ b/func-app/graphrag/prompt_tune/generator/defaults.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Default values for the fine-tuning module.""" + +DEFAULT_TASK = """ +Identify the relations and structure of the community of interest, specifically within the {domain} domain. +""" + +MAX_TOKEN_COUNT = 2000 diff --git a/func-app/graphrag/prompt_tune/generator/domain.py b/func-app/graphrag/prompt_tune/generator/domain.py new file mode 100644 index 0000000000..49c698d1b4 --- /dev/null +++ b/func-app/graphrag/prompt_tune/generator/domain.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Domain generation for GraphRAG prompts.""" + +from graphrag.llm.types.llm_types import CompletionLLM +from graphrag.prompt_tune.prompt.domain import GENERATE_DOMAIN_PROMPT + + +async def generate_domain(llm: CompletionLLM, docs: str | list[str]) -> str: + """Generate an LLM persona to use for GraphRAG prompts. + + Parameters + ---------- + - llm (CompletionLLM): The LLM to use for generation + - docs (str | list[str]): The domain to generate a persona for + + Returns + ------- + - str: The generated domain prompt response. + """ + docs_str = " ".join(docs) if isinstance(docs, list) else docs + domain_prompt = GENERATE_DOMAIN_PROMPT.format(input_text=docs_str) + + response = await llm(domain_prompt) + + return str(response.output) diff --git a/func-app/graphrag/prompt_tune/generator/entity_extraction_prompt.py b/func-app/graphrag/prompt_tune/generator/entity_extraction_prompt.py new file mode 100644 index 0000000000..3b17dbab5d --- /dev/null +++ b/func-app/graphrag/prompt_tune/generator/entity_extraction_prompt.py @@ -0,0 +1,107 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Entity Extraction prompt generator module.""" + +from pathlib import Path + +import graphrag.config.defaults as defs +from graphrag.index.utils.tokens import num_tokens_from_string +from graphrag.prompt_tune.template import ( + EXAMPLE_EXTRACTION_TEMPLATE, + GRAPH_EXTRACTION_JSON_PROMPT, + GRAPH_EXTRACTION_PROMPT, + UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE, + UNTYPED_GRAPH_EXTRACTION_PROMPT, +) + +ENTITY_EXTRACTION_FILENAME = "entity_extraction.txt" + + +def create_entity_extraction_prompt( + entity_types: str | list[str] | None, + docs: list[str], + examples: list[str], + language: str, + max_token_count: int, + encoding_model: str = defs.ENCODING_MODEL, + json_mode: bool = False, + output_path: Path | None = None, + min_examples_required: int = 2, +) -> str: + """ + Create a prompt for entity extraction. 
+ + Parameters + ---------- + - entity_types (str | list[str]): The entity types to extract + - docs (list[str]): The list of documents to extract entities from + - examples (list[str]): The list of examples to use for entity extraction + - language (str): The language of the inputs and outputs + - encoding_model (str): The name of the model to use for token counting + - max_token_count (int): The maximum number of tokens to use for the prompt + - json_mode (bool): Whether to use JSON mode for the prompt. Default is False + - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None. + - min_examples_required (int): The minimum number of examples required. Default is 2. + + Returns + ------- + - str: The entity extraction prompt + """ + prompt = ( + (GRAPH_EXTRACTION_JSON_PROMPT if json_mode else GRAPH_EXTRACTION_PROMPT) + if entity_types + else UNTYPED_GRAPH_EXTRACTION_PROMPT + ) + if isinstance(entity_types, list): + entity_types = ", ".join(entity_types) + + tokens_left = ( + max_token_count + - num_tokens_from_string(prompt, model=encoding_model) + - num_tokens_from_string(entity_types, model=encoding_model) + if entity_types + else 0 + ) + + examples_prompt = "" + + # Iterate over examples, while we have tokens left or examples left + for i, output in enumerate(examples): + input = docs[i] + example_formatted = ( + EXAMPLE_EXTRACTION_TEMPLATE.format( + n=i + 1, input_text=input, entity_types=entity_types, output=output + ) + if entity_types + else UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE.format( + n=i + 1, input_text=input, output=output + ) + ) + + example_tokens = num_tokens_from_string(example_formatted, model=encoding_model) + + # Ensure at least three examples are included + if i >= min_examples_required and example_tokens > tokens_left: + break + + examples_prompt += example_formatted + tokens_left -= example_tokens + + prompt = ( + prompt.format( + entity_types=entity_types, examples=examples_prompt, language=language + ) + if entity_types + else prompt.format(examples=examples_prompt, language=language) + ) + + if output_path: + output_path.mkdir(parents=True, exist_ok=True) + + output_path = output_path / ENTITY_EXTRACTION_FILENAME + # Write file to output path + with output_path.open("wb") as file: + file.write(prompt.encode(encoding="utf-8", errors="strict")) + + return prompt diff --git a/func-app/graphrag/prompt_tune/generator/entity_relationship.py b/func-app/graphrag/prompt_tune/generator/entity_relationship.py new file mode 100644 index 0000000000..72ecb5f4da --- /dev/null +++ b/func-app/graphrag/prompt_tune/generator/entity_relationship.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Entity relationship example generation module.""" + +import asyncio +import json + +from graphrag.llm.types.llm_types import CompletionLLM +from graphrag.prompt_tune.prompt import ( + ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT, + ENTITY_RELATIONSHIPS_GENERATION_PROMPT, + UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT, +) + +MAX_EXAMPLES = 5 + + +async def generate_entity_relationship_examples( + llm: CompletionLLM, + persona: str, + entity_types: str | list[str] | None, + docs: str | list[str], + language: str, + json_mode: bool = False, +) -> list[str]: + """Generate a list of entity/relationships examples for use in generating an entity configuration. 
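+    One example is generated per input document (up to MAX_EXAMPLES documents), and
+    the LLM requests are issued concurrently.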
+ + Will return entity/relationships examples as either JSON or in tuple_delimiter format depending + on the json_mode parameter. + """ + docs_list = [docs] if isinstance(docs, str) else docs + history = [{"role": "system", "content": persona}] + + if entity_types: + entity_types_str = ( + entity_types if isinstance(entity_types, str) else ", ".join(entity_types) + ) + + messages = [ + ( + ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT + if json_mode + else ENTITY_RELATIONSHIPS_GENERATION_PROMPT + ).format(entity_types=entity_types_str, input_text=doc, language=language) + for doc in docs_list + ] + else: + messages = [ + UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT.format( + input_text=doc, language=language + ) + for doc in docs_list + ] + + messages = messages[:MAX_EXAMPLES] + + tasks = [llm(message, history=history, json=json_mode) for message in messages] + + responses = await asyncio.gather(*tasks) + + return [ + json.dumps(response.json or "") if json_mode else str(response.output) + for response in responses + ] diff --git a/func-app/graphrag/prompt_tune/generator/entity_summarization_prompt.py b/func-app/graphrag/prompt_tune/generator/entity_summarization_prompt.py new file mode 100644 index 0000000000..4ae5af77ec --- /dev/null +++ b/func-app/graphrag/prompt_tune/generator/entity_summarization_prompt.py @@ -0,0 +1,36 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Entity summarization prompt generation module.""" + +from pathlib import Path + +from graphrag.prompt_tune.template import ENTITY_SUMMARIZATION_PROMPT + +ENTITY_SUMMARIZATION_FILENAME = "summarize_descriptions.txt" + + +def create_entity_summarization_prompt( + persona: str, + language: str, + output_path: Path | None = None, +) -> str: + """Create a prompt for entity summarization. If output_path is provided, write the prompt to a file. + + Parameters + ---------- + - persona (str): The persona to use for the entity summarization prompt + - language (str): The language to use for the entity summarization prompt + - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None. + """ + prompt = ENTITY_SUMMARIZATION_PROMPT.format(persona=persona, language=language) + + if output_path: + output_path.mkdir(parents=True, exist_ok=True) + + output_path = output_path / ENTITY_SUMMARIZATION_FILENAME + # Write file to output path + with output_path.open("wb") as file: + file.write(prompt.encode(encoding="utf-8", errors="strict")) + + return prompt diff --git a/func-app/graphrag/prompt_tune/generator/entity_types.py b/func-app/graphrag/prompt_tune/generator/entity_types.py new file mode 100644 index 0000000000..42518acd8c --- /dev/null +++ b/func-app/graphrag/prompt_tune/generator/entity_types.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Entity type generation module for fine-tuning.""" + +from graphrag.llm.types.llm_types import CompletionLLM +from graphrag.prompt_tune.generator.defaults import DEFAULT_TASK +from graphrag.prompt_tune.prompt.entity_types import ( + ENTITY_TYPE_GENERATION_JSON_PROMPT, + ENTITY_TYPE_GENERATION_PROMPT, +) + + +async def generate_entity_types( + llm: CompletionLLM, + domain: str, + persona: str, + docs: str | list[str], + task: str = DEFAULT_TASK, + json_mode: bool = False, +) -> str | list[str]: + """ + Generate entity type categories from a given set of (small) documents. 
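+    When json_mode is True, the "entity_types" field of the parsed JSON response is
+    returned as a list; otherwise the raw text output is returned.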
+ + Example Output: + "entity_types": ['military unit', 'organization', 'person', 'location', 'event', 'date', 'equipment'] + """ + formatted_task = task.format(domain=domain) + + docs_str = "\n".join(docs) if isinstance(docs, list) else docs + + entity_types_prompt = ( + ENTITY_TYPE_GENERATION_JSON_PROMPT + if json_mode + else ENTITY_TYPE_GENERATION_PROMPT + ).format(task=formatted_task, input_text=docs_str) + + history = [{"role": "system", "content": persona}] + + response = await llm(entity_types_prompt, history=history, json=json_mode) + + if json_mode: + return (response.json or {}).get("entity_types", []) + + return str(response.output) diff --git a/func-app/graphrag/prompt_tune/generator/language.py b/func-app/graphrag/prompt_tune/generator/language.py new file mode 100644 index 0000000000..38de531ca3 --- /dev/null +++ b/func-app/graphrag/prompt_tune/generator/language.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Language detection for GraphRAG prompts.""" + +from graphrag.llm.types.llm_types import CompletionLLM +from graphrag.prompt_tune.prompt import DETECT_LANGUAGE_PROMPT + + +async def detect_language(llm: CompletionLLM, docs: str | list[str]) -> str: + """Detect input language to use for GraphRAG prompts. + + Parameters + ---------- + - llm (CompletionLLM): The LLM to use for generation + - docs (str | list[str]): The docs to detect language from + + Returns + ------- + - str: The detected language. + """ + docs_str = " ".join(docs) if isinstance(docs, list) else docs + language_prompt = DETECT_LANGUAGE_PROMPT.format(input_text=docs_str) + + response = await llm(language_prompt) + + return str(response.output) diff --git a/func-app/graphrag/prompt_tune/generator/persona.py b/func-app/graphrag/prompt_tune/generator/persona.py new file mode 100644 index 0000000000..cdd57a655d --- /dev/null +++ b/func-app/graphrag/prompt_tune/generator/persona.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Persona generating module for fine-tuning GraphRAG prompts.""" + +from graphrag.llm.types.llm_types import CompletionLLM +from graphrag.prompt_tune.generator.defaults import DEFAULT_TASK +from graphrag.prompt_tune.prompt import GENERATE_PERSONA_PROMPT + + +async def generate_persona( + llm: CompletionLLM, domain: str, task: str = DEFAULT_TASK +) -> str: + """Generate an LLM persona to use for GraphRAG prompts. + + Parameters + ---------- + - llm (CompletionLLM): The LLM to use for generation + - domain (str): The domain to generate a persona for + - task (str): The task to generate a persona for. Default is DEFAULT_TASK + """ + formatted_task = task.format(domain=domain) + persona_prompt = GENERATE_PERSONA_PROMPT.format(sample_task=formatted_task) + + response = await llm(persona_prompt) + + return str(response.output) diff --git a/func-app/graphrag/prompt_tune/loader/__init__.py b/func-app/graphrag/prompt_tune/loader/__init__.py new file mode 100644 index 0000000000..94e64cbe87 --- /dev/null +++ b/func-app/graphrag/prompt_tune/loader/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Fine-tuning config and data loader module.""" + +from .config import read_config_parameters +from .input import MIN_CHUNK_OVERLAP, MIN_CHUNK_SIZE, load_docs_in_chunks + +__all__ = [ + "MIN_CHUNK_OVERLAP", + "MIN_CHUNK_SIZE", + "load_docs_in_chunks", + "read_config_parameters", +] diff --git a/func-app/graphrag/prompt_tune/loader/config.py b/func-app/graphrag/prompt_tune/loader/config.py new file mode 100644 index 0000000000..8994604f92 --- /dev/null +++ b/func-app/graphrag/prompt_tune/loader/config.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Config loading, parsing and handling module.""" + +from pathlib import Path + +from graphrag.config import create_graphrag_config +from graphrag.index.progress.types import ProgressReporter + + +def read_config_parameters(root: str, reporter: ProgressReporter): + """Read the configuration parameters from the settings file or environment variables. + + Parameters + ---------- + - root: The root directory where the parameters are. + - reporter: The progress reporter. + """ + _root = Path(root) + settings_yaml = _root / "settings.yaml" + if not settings_yaml.exists(): + settings_yaml = _root / "settings.yml" + settings_json = _root / "settings.json" + + if settings_yaml.exists(): + reporter.info(f"Reading settings from {settings_yaml}") + with settings_yaml.open("rb") as file: + import yaml + + data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + if settings_json.exists(): + reporter.info(f"Reading settings from {settings_json}") + with settings_json.open("rb") as file: + import json + + data = json.loads(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + reporter.info("Reading settings from environment variables") + return create_graphrag_config(root_dir=root) diff --git a/func-app/graphrag/prompt_tune/loader/input.py b/func-app/graphrag/prompt_tune/loader/input.py new file mode 100644 index 0000000000..86c4a76040 --- /dev/null +++ b/func-app/graphrag/prompt_tune/loader/input.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Input loading module."""
+
+from typing import cast
+
+import numpy as np
+import pandas as pd
+from datashaper import NoopVerbCallbacks, TableContainer, VerbInput
+
+import graphrag.config.defaults as defs
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.input import load_input
+from graphrag.index.llm import load_llm_embeddings
+from graphrag.index.progress.types import ProgressReporter
+from graphrag.index.verbs import chunk
+from graphrag.llm.types.llm_types import EmbeddingLLM
+
+MIN_CHUNK_OVERLAP = 0
+MIN_CHUNK_SIZE = 200
+N_SUBSET_MAX = 300
+K = 15
+
+
+async def _embed_chunks(
+    text_chunks: pd.DataFrame,
+    embedding_llm: EmbeddingLLM,
+    n_subset_max: int = N_SUBSET_MAX,
+) -> tuple[pd.DataFrame, np.ndarray]:
+    """Convert text chunks into dense text embeddings."""
+    sampled_text_chunks = text_chunks.sample(n=min(n_subset_max, len(text_chunks)))
+    embeddings = await embedding_llm(sampled_text_chunks["chunks"].tolist())
+    return text_chunks, np.array(embeddings.output)
+
+
+def _sample_chunks_from_embeddings(
+    text_chunks: pd.DataFrame,
+    embeddings,
+    k: int = K,
+) -> pd.DataFrame:
+    """Sample the k text chunks whose embeddings lie closest to the centroid."""
+    center = np.mean(embeddings, axis=0)
+    distances = np.linalg.norm(embeddings - center, axis=1)
+    nearest_indices = np.argsort(distances)[:k]
+
+    return text_chunks.iloc[nearest_indices]
+
+
+async def load_docs_in_chunks(
+    root: str,
+    config: GraphRagConfig,
+    select_method: str,
+    limit: int,
+    reporter: ProgressReporter,
+    chunk_size: int = MIN_CHUNK_SIZE,
+    n_subset_max: int = N_SUBSET_MAX,
+    k: int = K,
+) -> list[str]:
+    """Load docs into chunks for generating prompts."""
+    dataset = await load_input(config.input, reporter, root)
+
+    # convert to text units
+    input = VerbInput(input=TableContainer(table=dataset))
+    chunk_strategy = config.chunks.resolved_strategy(defs.ENCODING_MODEL)
+
+    # Use smaller chunks, to avoid huge prompts
+    chunk_strategy["chunk_size"] = chunk_size
+    chunk_strategy["chunk_overlap"] = MIN_CHUNK_OVERLAP
+
+    dataset_chunks_table_container = chunk(
+        input,
+        column="text",
+        to="chunks",
+        callbacks=NoopVerbCallbacks(),
+        strategy=chunk_strategy,
+    )
+
+    dataset_chunks = cast(pd.DataFrame, dataset_chunks_table_container.table)
+
+    # Select chunks into a new df and explode it
+    chunks_df = pd.DataFrame(dataset_chunks["chunks"].explode())  # type: ignore
+
+    # Depending on the select method, build the dataset
+    if limit <= 0 or limit > len(chunks_df):
+        limit = len(chunks_df)
+
+    if select_method == "top":
+        chunks_df = chunks_df[:limit]
+    elif select_method == "random":
+        chunks_df = chunks_df.sample(n=limit)
+    elif select_method == "auto":
+        if k is None or k <= 0:
+            msg = "k must be an integer > 0"
+            raise ValueError(msg)
+        embedding_llm = load_llm_embeddings(
+            name="prompt_tuning_embeddings",
+            llm_type=config.embeddings.resolved_strategy()["llm"]["type"],
+            callbacks=NoopVerbCallbacks(),
+            cache=None,
+            llm_config=config.embeddings.resolved_strategy()["llm"],
+        )
+
+        chunks_df, embeddings = await _embed_chunks(
+            chunks_df, embedding_llm, n_subset_max=n_subset_max
+        )
+        chunks_df = _sample_chunks_from_embeddings(chunks_df, embeddings, k=k)
+
+    # Convert the dataset to list form, so we have a list of documents
+    return chunks_df["chunks"].tolist()
diff --git a/func-app/graphrag/prompt_tune/prompt/__init__.py b/func-app/graphrag/prompt_tune/prompt/__init__.py
new file mode 100644
index 0000000000..991d52856e
--- /dev/null
+++
b/func-app/graphrag/prompt_tune/prompt/__init__.py @@ -0,0 +1,32 @@ +"""Persona, entity type, relationships and domain generation prompts module.""" + +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +from .community_report_rating import GENERATE_REPORT_RATING_PROMPT +from .community_reporter_role import GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT +from .domain import GENERATE_DOMAIN_PROMPT +from .entity_relationship import ( + ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT, + ENTITY_RELATIONSHIPS_GENERATION_PROMPT, + UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT, +) +from .entity_types import ( + ENTITY_TYPE_GENERATION_JSON_PROMPT, + ENTITY_TYPE_GENERATION_PROMPT, +) +from .language import DETECT_LANGUAGE_PROMPT +from .persona import GENERATE_PERSONA_PROMPT + +__all__ = [ + "DETECT_LANGUAGE_PROMPT", + "ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT", + "ENTITY_RELATIONSHIPS_GENERATION_PROMPT", + "ENTITY_TYPE_GENERATION_JSON_PROMPT", + "ENTITY_TYPE_GENERATION_PROMPT", + "GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT", + "GENERATE_DOMAIN_PROMPT", + "GENERATE_PERSONA_PROMPT", + "GENERATE_REPORT_RATING_PROMPT", + "UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT", +] diff --git a/func-app/graphrag/prompt_tune/prompt/community_report_rating.py b/func-app/graphrag/prompt_tune/prompt/community_report_rating.py new file mode 100644 index 0000000000..b061645b94 --- /dev/null +++ b/func-app/graphrag/prompt_tune/prompt/community_report_rating.py @@ -0,0 +1,132 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Fine tuning prompts for Community Reports Rating.""" + +GENERATE_REPORT_RATING_PROMPT = """ + +You are a helpful agent tasked with rating the importance of a given text in the context of the provided domain and persona. Your goal is to provide a rating that reflects the relevance and significance of the text to the specified domain and persona. Use your expertise to evaluate the text based on the importance criteria and assign a float score between 0-10. Only respond with the text description of the importance criteria. Use the provided example data format to guide your response. Ignore the content of the example data and focus on the structure. + +###################### +-Examples- +###################### + +### Example 1 + +# Domain + +Personal and Family Communication + +# Persona + +You are an expert in Social Network Analysis with a focus on the Personal and Family Communication domain. You are skilled at mapping and interpreting complex social networks, understanding the dynamics of interpersonal relationships, and identifying patterns of communication within communities. You are adept at helping people understand the structure and relations within their personal and family networks, providing insights into how information flows, how strong various connections are, and how these networks influence individual and group behavior. + +# Data + + +Subject: Re: Event +From: Alice Brown alice.brown@example.com +Date: 2012-11-14, 9:52 a.m. +To: John Smith john.smith@example.com +CC: Jane Doe jane.doe@example.com, Bob Johnson bob.johnson@example.com, Emma Davis emma.davis@example.com + +The event is at 6pm at City Hall (Queen street) event chamber. We +just need to get there by 5:45pm. It is 30-minute long so we will be +done by 6:30pm. We'll then head over to New Sky on Spadina for some +unique cuisine! + +Guests are you and Emma, and my uncle and auntie from London +who my folks have designated to act as their reps. 
Jane and Joe are +witnesses. + +Be there or be square! +Alice + +On Wed, Nov 14, 2012 at 9:40 AM, John Smith john.smith@example.com wrote: + +Thats the day after Bob's event! +Any more details on the event schedule? ITS NEXT WEEK! +On Tue, Nov 13, 2012 at 7:51 PM, Jane Doe +jane.doe@example.com wrote: +I am supposed to forward you the invitation to this year's celebration. +Date: Saturday, Nov. 24, 6 pm starting +Place as usual: Dean's house, 6 Cardish, Kleinburg L0J 1C0 +Jane Doe +jane.doe@example.com + +# Importance Criteria + +A float score between 0-10 that represents the relevance of the email's content to family communication, health concerns, travel plans, and interpersonal dynamics, with 1 being trivial or spam and 10 being highly relevant, urgent, and impactful to family cohesion or well-being. +############################# + +### Example 2 + +# Domain + +Literary Analysis + +# Persona + +You are a literary scholar with a focus on works from the 19th century. You are skilled at analyzing and interpreting texts, identifying themes and motifs, and understanding the historical and cultural contexts in which these works were written. You are adept at helping people understand the deeper meanings and significance of literary works, providing insights into the author's intentions, the social issues addressed in the text, and the impact of these works on contemporary society. + +# Data + +Had she found Jane in any apparent danger, Mrs. Bennet would have been very miserable; but being satisfied on seeing her that her illness was not alarming, she had no wish of her recovering immediately, as her restoration to health would probably remove her from Netherfield. She would not listen, therefore, to her daughter's proposal of being carried home; neither did the apothecary, who arrived about the same time, think it at all advisable. After sitting a little with Jane, on Miss Bingley's appearance and invitation, the mother and three daughters all attended her into the breakfast parlor. Bingley met them with hopes that Mrs. Bennet had not found Miss Bennet worse than she expected. + +"Indeed I have, Sir," was her answer. "She is a great deal too ill to be moved. Mr. Jones says we must not think of moving her. We must trespass a little longer on your kindness." + +"Removed!" cried Bingley. "It must not be thought of. My sister, I am sure, will not hear of her removal." + +# Importance Criteria + +A float score between 0-10 that represents the relevance of the text to literary analysis, historical context, thematic interpretation, and cultural significance, with 1 being trivial or irrelevant and 10 being highly significant, profound, and impactful to the understanding of the text and its implications. +############################# + +### Example 3 + +# Domain + +Environmental Science + +# Persona + +You are an environmental scientist with a focus on climate change and sustainability. You are skilled at analyzing data, interpreting social commentary and recommending policy changes. You are adept at helping people understand the causes and consequences of climate change, providing insights into how they can reduce their carbon footprint, adopt sustainable practices, and contribute to a healthier planet. + +# Data + +Host 1 (Anna): Welcome to "Green Living Today," the podcast where we explore practical tips and inspiring stories about sustainable living. I'm your host, Anna Green. + +Host 2 (Mark): And I'm Mark Smith. 
Today, we have a special episode focused on reducing plastic waste in our daily lives. We'll be talking to a special guest who has made significant strides in living a plastic-free lifestyle. + +Anna: That's right, Mark. Our guest today is Laura Thompson, the founder of "Plastic-Free Living," a blog dedicated to sharing tips and resources for reducing plastic use. Welcome to the show, Laura! + +Guest (Laura): Thanks, Anna and Mark. It's great to be here. + +Mark: Laura, let's start by talking about your journey. What inspired you to start living a plastic-free lifestyle? + +# Importance Criteria + +A float score between 0-10 that represents the relevance of the text to sustainability, plastic waste reduction, and environmental policies, with 1 being trivial or irrelevant and 10 being highly significant, impactful, and actionable in promoting environmental awareness. +############################# + + +############################# +-Real Data- +############################# + +# Domain + +{domain} + +# Persona + +{persona} + +# Data + +{input_text} + +# Importance Criteria + + +""" diff --git a/func-app/graphrag/prompt_tune/prompt/community_reporter_role.py b/func-app/graphrag/prompt_tune/prompt/community_reporter_role.py new file mode 100644 index 0000000000..b667bc2940 --- /dev/null +++ b/func-app/graphrag/prompt_tune/prompt/community_reporter_role.py @@ -0,0 +1,20 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Fine-tuning prompts for community reporter role generation.""" + +GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT = """ +{persona} +Given a sample text, help the user by creating a role definition that will be tasked with community analysis. +Take a look at this example, determine its key parts, and using the domain provided and your expertise, create a new role definition for the provided inputs that follows the same pattern as the example. +Remember, your output should look just like the provided example in structure and content. + +Example: +A technologist reporter that is analyzing Kevin Scott's "Behind the Tech Podcast", given a list of entities +that belong to the community as well as their relationships and optional associated claims. +The report will be used to inform decision-makers about significant developments associated with the community and their potential impact. + + +Domain: {domain} +Text: {input_text} +Role:""" diff --git a/func-app/graphrag/prompt_tune/prompt/domain.py b/func-app/graphrag/prompt_tune/prompt/domain.py new file mode 100644 index 0000000000..4b4587f8d8 --- /dev/null +++ b/func-app/graphrag/prompt_tune/prompt/domain.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Fine-tuning prompts for domain generation.""" + +GENERATE_DOMAIN_PROMPT = """ +You are an intelligent assistant that helps a human to analyze the information in a text document. +Given a sample text, help the user by assigning a descriptive domain that summarizes what the text is about. +Example domains are: "Social studies", "Algorithmic analysis", "Medical science", among others. + +Text: {input_text} +Domain:""" diff --git a/func-app/graphrag/prompt_tune/prompt/entity_relationship.py b/func-app/graphrag/prompt_tune/prompt/entity_relationship.py new file mode 100644 index 0000000000..3af77db641 --- /dev/null +++ b/func-app/graphrag/prompt_tune/prompt/entity_relationship.py @@ -0,0 +1,355 @@ +# Copyright (c) 2024 Microsoft Corporation. 
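The rating prompt defined above exposes three format slots ({domain}, {persona}, {input_text}) and no other single-braced tokens, so a plain str.format renders it. A minimal sketch, not part of the patch; the values are illustrative stand-ins for the outputs of the earlier domain and persona generation steps:

from graphrag.prompt_tune.prompt import GENERATE_REPORT_RATING_PROMPT

# Illustrative inputs; in the pipeline these come from prior generation steps.
domain = "Environmental Science"
persona = "You are an environmental scientist focused on climate change and sustainability."
sample_text = "Host 1 (Anna): Welcome to Green Living Today..."

# Fills the three slots; the rendered prompt is then sent to the chat model,
# which replies with the importance-criteria description.
rating_prompt = GENERATE_REPORT_RATING_PROMPT.format(
    domain=domain,
    persona=persona,
    input_text=sample_text,
)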
+# Licensed under the MIT License + +"""Fine-tuning prompts for entity relationship generation.""" + +ENTITY_RELATIONSHIPS_GENERATION_PROMPT = """ +-Goal- +Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. + +-Steps- +1. Identify all entities. For each identified entity, extract the following information: +- entity_name: Name of the entity, capitalized +- entity_type: One of the following types: [{entity_types}] +- entity_description: Comprehensive description of the entity's attributes and activities +Format each entity as ("entity"{{tuple_delimiter}}<entity_name>{{tuple_delimiter}}<entity_type>{{tuple_delimiter}}<entity_description>) + +2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. +For each pair of related entities, extract the following information: +- source_entity: name of the source entity, as identified in step 1 +- target_entity: name of the target entity, as identified in step 1 +- relationship_description: explanation as to why you think the source entity and the target entity are related to each other +- relationship_strength: an integer score between 1 and 10, indicating strength of the relationship between the source entity and target entity +Format each relationship as ("relationship"{{tuple_delimiter}}<source_entity>{{tuple_delimiter}}<target_entity>{{tuple_delimiter}}<relationship_description>{{tuple_delimiter}}<relationship_strength>) + +3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{{record_delimiter}}** as the list delimiter. + +4. If you have to translate into {language}, just translate the descriptions, nothing else! + +5. When finished, output {{completion_delimiter}}. + +###################### +-Examples- +###################### +Example 1: +Entity_types: ORGANIZATION,PERSON +Text: +The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%. +###################### +Output: +("entity"{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday) +{{record_delimiter}} +("entity"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Martin Smith is the chair of the Central Institution) +{{record_delimiter}} +("entity"{{tuple_delimiter}}MARKET STRATEGY COMMITTEE{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}Martin Smith is the Chair of the Central Institution and will answer questions at a press conference{{tuple_delimiter}}9) +{{completion_delimiter}} + +###################### +Example 2: +Entity_types: ORGANIZATION +Text: +TechGlobal's (TG) stock skyrocketed in its opening day on the Global Exchange Thursday.
But IPO experts warn that the semiconductor corporation's debut on the public markets isn't indicative of how other newly listed companies may perform. + +TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones. +###################### +Output: +("entity"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) +{{record_delimiter}} +("entity"{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}Vision Holdings is a firm that previously owned TechGlobal) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}Vision Holdings formerly owned TechGlobal from 2014 until present{{tuple_delimiter}}5) +{{completion_delimiter}} + +###################### +Example 3: +Entity_types: ORGANIZATION,GEO,PERSON +Text: +Five Aurelians jailed for 8 years in Firuzabad and widely regarded as hostages are on their way home to Aurelia. + +The swap orchestrated by Quintara was finalized when $8bn of Firuzi funds were transferred to financial institutions in Krohaara, the capital of Quintara. + +The exchange initiated in Firuzabad's capital, Tiruzia, led to the four men and one woman, who are also Firuzi nationals, boarding a chartered flight to Krohaara. + +They were welcomed by senior Aurelian officials and are now on their way to Aurelia's capital, Cashion. + +The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality. 
+###################### +Output: +("entity"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}GEO{{tuple_delimiter}}Firuzabad held Aurelians as hostages) +{{record_delimiter}} +("entity"{{tuple_delimiter}}AURELIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country seeking to release hostages) +{{record_delimiter}} +("entity"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country that negotiated a swap of money in exchange for hostages) +{{record_delimiter}} +("entity"{{tuple_delimiter}}TIRUZIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital of Firuzabad where the Aurelians were being held) +{{record_delimiter}} +("entity"{{tuple_delimiter}}KROHAARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Quintara) +{{record_delimiter}} +("entity"{{tuple_delimiter}}CASHION{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Aurelia) +{{record_delimiter}} +("entity"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian who spent time in Tiruzia's Alhamia Prison) +{{record_delimiter}} +("entity"{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}GEO{{tuple_delimiter}}Prison in Tiruzia) +{{record_delimiter}} +("entity"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian journalist who was held hostage) +{{record_delimiter}} +("entity"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Bratinas national and environmentalist who was held hostage) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Firuzabad negotiated a hostage exchange with Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}Samuel Namara was a prisoner at Alhamia prison{{tuple_delimiter}}8) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}Samuel Namara and Meggie Tazbah were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Samuel Namara and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Samuel Namara was a hostage in Firuzabad{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Meggie Tazbah was a hostage in Firuzabad{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Durke Bataglani was a hostage in Firuzabad{{tuple_delimiter}}2) +{{completion_delimiter}} + +-Real Data- +###################### +entity_types:
{entity_types} +text: {input_text} +###################### +output: +""" + +ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT = """ +-Goal- +Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. + +-Steps- +1. Identify all entities. For each identified entity, extract the following information: +- entity_name: Name of the entity, capitalized +- entity_type: One of the following types: [{entity_types}] +- entity_description: Comprehensive description of the entity's attributes and activities + +Format each entity output as a JSON entry with the following format: + +{{"name": <entity_name>, "type": <entity_type>, "description": <entity_description>}} + +2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. +For each pair of related entities, extract the following information: +- source_entity: name of the source entity, as identified in step 1 +- target_entity: name of the target entity, as identified in step 1 +- relationship_description: explanation as to why you think the source entity and the target entity are related to each other +- relationship_strength: an integer score between 1 and 10, indicating strength of the relationship between the source entity and target entity + +Format each relationship as a JSON entry with the following format: + +{{"source": <source_entity>, "target": <target_entity>, "relationship": <relationship_description>, "relationship_strength": <relationship_strength>}} + +3. Return output in {language} as a single list of all JSON entities and relationships identified in steps 1 and 2. + +4. If you have to translate into {language}, just translate the descriptions, nothing else! + +###################### +-Examples- +###################### +Example 1: +Text: +The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%. +###################### +Output: +[ + {{"name": "CENTRAL INSTITUTION", "type": "ORGANIZATION", "description": "The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday"}}, + {{"name": "MARTIN SMITH", "type": "PERSON", "description": "Martin Smith is the chair of the Central Institution"}}, + {{"name": "MARKET STRATEGY COMMITTEE", "type": "ORGANIZATION", "description": "The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply"}}, + {{"source": "MARTIN SMITH", "target": "CENTRAL INSTITUTION", "relationship": "Martin Smith is the Chair of the Central Institution and will answer questions at a press conference", "relationship_strength": 9}} +] + +###################### +Example 2: +Text: +TechGlobal's (TG) stock skyrocketed in its opening day on the Global Exchange Thursday. But IPO experts warn that the semiconductor corporation's debut on the public markets isn't indicative of how other newly listed companies may perform. + +TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones.
+###################### +Output: +[ + {{"name": "TECHGLOBAL", "type": "ORGANIZATION", "description": "TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones"}}, + {{"name": "VISION HOLDINGS", "type": "ORGANIZATION", "description": "Vision Holdings is a firm that previously owned TechGlobal"}}, + {{"source": "TECHGLOBAL", "target": "VISION HOLDINGS", "relationship": "Vision Holdings formerly owned TechGlobal from 2014 until present", "relationship_strength": 5}} +] + +###################### +Example 3: +Text: +Five Aurelians jailed for 8 years in Firuzabad and widely regarded as hostages are on their way home to Aurelia. + +The swap orchestrated by Quintara was finalized when $8bn of Firuzi funds were transferred to financial institutions in Krohaara, the capital of Quintara. + +The exchange initiated in Firuzabad's capital, Tiruzia, led to the four men and one woman, who are also Firuzi nationals, boarding a chartered flight to Krohaara. + +They were welcomed by senior Aurelian officials and are now on their way to Aurelia's capital, Cashion. + +The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality. +###################### +Output: +[ + {{"name": "FIRUZABAD", "type": "GEO", "description": "Firuzabad held Aurelians as hostages"}}, + {{"name": "AURELIA", "type": "GEO", "description": "Country seeking to release hostages"}}, + {{"name": "QUINTARA", "type": "GEO", "description": "Country that negotiated a swap of money in exchange for hostages"}}, + {{"name": "TIRUZIA", "type": "GEO", "description": "Capital of Firuzabad where the Aurelians were being held"}}, + {{"name": "KROHAARA", "type": "GEO", "description": "Capital city in Quintara"}}, + {{"name": "CASHION", "type": "GEO", "description": "Capital city in Aurelia"}}, + {{"name": "SAMUEL NAMARA", "type": "PERSON", "description": "Aurelian who spent time in Tiruzia's Alhamia Prison"}}, + {{"name": "ALHAMIA PRISON", "type": "GEO", "description": "Prison in Tiruzia"}}, + {{"name": "DURKE BATAGLANI", "type": "PERSON", "description": "Aurelian journalist who was held hostage"}}, + {{"name": "MEGGIE TAZBAH", "type": "PERSON", "description": "Bratinas national and environmentalist who was held hostage"}}, + {{"source": "FIRUZABAD", "target": "AURELIA", "relationship": "Firuzabad negotiated a hostage exchange with Aurelia", "relationship_strength": 2}}, + {{"source": "QUINTARA", "target": "AURELIA", "relationship": "Quintara brokered the hostage exchange between Firuzabad and Aurelia", "relationship_strength": 2}}, + {{"source": "QUINTARA", "target": "FIRUZABAD", "relationship": "Quintara brokered the hostage exchange between Firuzabad and Aurelia", "relationship_strength": 2}}, + {{"source": "SAMUEL NAMARA", "target": "ALHAMIA PRISON", "relationship": "Samuel Namara was a prisoner at Alhamia prison", "relationship_strength": 8}}, + {{"source": "SAMUEL NAMARA", "target": "MEGGIE TAZBAH", "relationship": "Samuel Namara and Meggie Tazbah were exchanged in the same hostage release", "relationship_strength": 2}}, + {{"source": "SAMUEL NAMARA", "target": "DURKE BATAGLANI", "relationship": "Samuel Namara and Durke Bataglani were exchanged in the same hostage release", "relationship_strength": 2}}, + {{"source": "MEGGIE TAZBAH", "target": "DURKE BATAGLANI", "relationship": "Meggie Tazbah and Durke Bataglani were exchanged in the same hostage 
release", "relationship_strength": 2}}, + {{"source": "SAMUEL NAMARA", "target": "FIRUZABAD", "relationship": "Samuel Namara was a hostage in Firuzabad", "relationship_strength": 2}}, + {{"source": "MEGGIE TAZBAH", "target": "FIRUZABAD", "relationship": "Meggie Tazbah was a hostage in Firuzabad", "relationship_strength": 2}}, + {{"source": "DURKE BATAGLANI", "target": "FIRUZABAD", "relationship": "Durke Bataglani was a hostage in Firuzabad", "relationship_strength": 2}} +] + + + +-Real Data- +###################### +entity_types: {entity_types} +text: {input_text} +###################### +output: +""" + +UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT = """ +-Goal- +Given a text document that is potentially relevant to this activity, first identify all entities needed from the text in order to capture the information and ideas in the text. +Next, report all relationships among the identified entities. + +-Steps- +1. Identify all entities. For each identified entity, extract the following information: +- entity_name: Name of the entity, capitalized +- entity_type: Suggest several labels or categories for the entity. The categories should not be specific, but should be as general as possible. +- entity_description: Comprehensive description of the entity's attributes and activities +Format each entity as ("entity"{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}) + +2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. +For each pair of related entities, extract the following information: +- source_entity: name of the source entity, as identified in step 1 +- target_entity: name of the target entity, as identified in step 1 +- relationship_description: explanation as to why you think the source entity and the target entity are related to each other +- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity +Format each relationship as ("relationship"{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}) + +3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{{record_delimiter}}** as the list delimiter. + +4. If you have to translate into {language}, just translate the descriptions, nothing else! + +5. When finished, output {{completion_delimiter}}. + +###################### +-Examples- +###################### +Example 1: +Text: +The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%. 
+###################### +Output: +("entity"{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday) +{{record_delimiter}} +("entity"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Martin Smith is the chair of the Central Institution) +{{record_delimiter}} +("entity"{{tuple_delimiter}}MARKET STRATEGY COMMITTEE{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}Martin Smith is the Chair of the Central Institution and will answer questions at a press conference{{tuple_delimiter}}9) +{{completion_delimiter}} + +###################### +Example 2: +Text: +TechGlobal's (TG) stock skyrocketed in its opening day on the Global Exchange Thursday. But IPO experts warn that the semiconductor corporation's debut on the public markets isn't indicative of how other newly listed companies may perform. + +TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones. +###################### +Output: +("entity"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) +{{record_delimiter}} +("entity"{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}Vision Holdings is a firm that previously owned TechGlobal) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}Vision Holdings formerly owned TechGlobal from 2014 until present{{tuple_delimiter}}5) +{{completion_delimiter}} + +###################### +Example 3: +Text: +Five Aurelians jailed for 8 years in Firuzabad and widely regarded as hostages are on their way home to Aurelia. + +The swap orchestrated by Quintara was finalized when $8bn of Firuzi funds were transferred to financial institutions in Krohaara, the capital of Quintara. + +The exchange initiated in Firuzabad's capital, Tiruzia, led to the four men and one woman, who are also Firuzi nationals, boarding a chartered flight to Krohaara. + +They were welcomed by senior Aurelian officials and are now on their way to Aurelia's capital, Cashion. + +The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality. 
+###################### +Output: +("entity"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}GEO{{tuple_delimiter}}Firuzabad held Aurelians as hostages) +{{record_delimiter}} +("entity"{{tuple_delimiter}}AURELIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country seeking to release hostages) +{{record_delimiter}} +("entity"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country that negotiated a swap of money in exchange for hostages) +{{record_delimiter}} +("entity"{{tuple_delimiter}}TIRUZIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital of Firuzabad where the Aurelians were being held) +{{record_delimiter}} +("entity"{{tuple_delimiter}}KROHAARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Quintara) +{{record_delimiter}} +("entity"{{tuple_delimiter}}CASHION{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Aurelia) +{{record_delimiter}} +("entity"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian who spent time in Tiruzia's Alhamia Prison) +{{record_delimiter}} +("entity"{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}GEO{{tuple_delimiter}}Prison in Tiruzia) +{{record_delimiter}} +("entity"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian journalist who was held hostage) +{{record_delimiter}} +("entity"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Bratinas national and environmentalist who was held hostage) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Firuzabad negotiated a hostage exchange with Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}Samuel Namara was a prisoner at Alhamia prison{{tuple_delimiter}}8) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}Samuel Namara and Meggie Tazbah were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Samuel Namara and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Samuel Namara was a hostage in Firuzabad{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Meggie Tazbah was a hostage in Firuzabad{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Durke Bataglani was a hostage in Firuzabad{{tuple_delimiter}}2) +{{completion_delimiter}} + +###################### +-Real Data- +######################
+Text: {input_text} +###################### +Output: +""" diff --git a/func-app/graphrag/prompt_tune/prompt/entity_types.py b/func-app/graphrag/prompt_tune/prompt/entity_types.py new file mode 100644 index 0000000000..99b21db645 --- /dev/null +++ b/func-app/graphrag/prompt_tune/prompt/entity_types.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Fine-tuning prompts for entity types generation.""" + +ENTITY_TYPE_GENERATION_PROMPT = """ +The goal is to study the connections and relations between the entity types and their features in order to understand all available information from the text. +The user's task is to {task}. +As part of the analysis, you want to identify the entity types present in the following text. +The entity types must be relevant to the user task. +Avoid general entity types such as "other" or "unknown". +This is VERY IMPORTANT: Do not generate redundant or overlapping entity types. For example, if the text contains "company" and "organization" entity types, you should return only one of them. +Don't worry about quantity, always choose quality over quantity. And make sure EVERYTHING in your answer is relevant to the context of entity extraction. +And remember, it is ENTITY TYPES that we need. +Return the entity types as a comma-separated list of strings. +===================================================================== +EXAMPLE SECTION: The following section includes example output. These examples **must be excluded from your answer**. + +EXAMPLE 1 +Task: Determine the connections and organizational hierarchy within the specified community. +Text: Example_Org_A is a company in Sweden. Example_Org_A's director is Example_Individual_B. +RESPONSE: +organization, person +END OF EXAMPLE 1 + +EXAMPLE 2 +Task: Identify the key concepts, principles, and arguments shared among different philosophical schools of thought, and trace the historical or ideological influences they have on each other. +Text: Rationalism, epitomized by thinkers such as René Descartes, holds that reason is the primary source of knowledge. Key concepts within this school include the emphasis on the deductive method of reasoning. +RESPONSE: +concept, person, school of thought +END OF EXAMPLE 2 + +EXAMPLE 3 +Task: Identify the full range of basic forces, factors, and trends that would indirectly shape an issue. +Text: Industry leaders such as Panasonic are vying for supremacy in the battery production sector. They are investing heavily in research and development and are exploring new technologies to gain a competitive edge. +RESPONSE: +organization, technology, sectors, investment strategies +END OF EXAMPLE 3 +====================================================================== + +====================================================================== +REAL DATA: The following section is the real data. You should use only this real data to prepare your answer. Generate Entity Types only. +Task: {task} +Text: {input_text} +RESPONSE: +{{<entity_types>}} +""" + +ENTITY_TYPE_GENERATION_JSON_PROMPT = """ +The goal is to study the connections and relations between the entity types and their features in order to understand all available information from the text. +The user's task is to {task}. +As part of the analysis, you want to identify the entity types present in the following text. +The entity types must be relevant to the user task. +Avoid general entity types such as "other" or "unknown".
+This is VERY IMPORTANT: Do not generate redundant or overlapping entity types. For example, if the text contains "company" and "organization" entity types, you should return only one of them. +Don't worry about quantity, always choose quality over quantity. And make sure EVERYTHING in your answer is relevant to the context of entity extraction. +Return the entity types in JSON format with "entities" as the key and the entity types as an array of strings. +===================================================================== +EXAMPLE SECTION: The following section includes example output. These examples **must be excluded from your answer**. + +EXAMPLE 1 +Task: Determine the connections and organizational hierarchy within the specified community. +Text: Example_Org_A is a company in Sweden. Example_Org_A's director is Example_Individual_B. +JSON RESPONSE: +{{"entity_types": [organization, person] }} +END OF EXAMPLE 1 + +EXAMPLE 2 +Task: Identify the key concepts, principles, and arguments shared among different philosophical schools of thought, and trace the historical or ideological influences they have on each other. +Text: Rationalism, epitomized by thinkers such as René Descartes, holds that reason is the primary source of knowledge. Key concepts within this school include the emphasis on the deductive method of reasoning. +JSON RESPONSE: +{{"entity_types": [concept, person, school of thought] }} +END OF EXAMPLE 2 + +EXAMPLE 3 +Task: Identify the full range of basic forces, factors, and trends that would indirectly shape an issue. +Text: Industry leaders such as Panasonic are vying for supremacy in the battery production sector. They are investing heavily in research and development and are exploring new technologies to gain a competitive edge. +JSON RESPONSE: +{{"entity_types": [organization, technology, sectors, investment strategies] }} +END OF EXAMPLE 3 +====================================================================== + +====================================================================== +REAL DATA: The following section is the real data. You should use only this real data to prepare your answer. Generate Entity Types only. +Task: {task} +Text: {input_text} +JSON response: +{{"entity_types": [<entity_types>] }} +""" diff --git a/func-app/graphrag/prompt_tune/prompt/language.py b/func-app/graphrag/prompt_tune/prompt/language.py new file mode 100644 index 0000000000..68fd04029f --- /dev/null +++ b/func-app/graphrag/prompt_tune/prompt/language.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Fine-tuning prompts for language detection.""" + +DETECT_LANGUAGE_PROMPT = """ +You are an intelligent assistant that helps a human to analyze the information in a text document. +Given a sample text, help the user by determining the primary language of the provided text. +Examples are: "English", "Spanish", "Japanese", "Portuguese", among others. + +Text: {input_text} +Language:""" diff --git a/func-app/graphrag/prompt_tune/prompt/persona.py b/func-app/graphrag/prompt_tune/prompt/persona.py new file mode 100644 index 0000000000..58515fd204 --- /dev/null +++ b/func-app/graphrag/prompt_tune/prompt/persona.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Fine-tuning prompts for persona generation.""" + +GENERATE_PERSONA_PROMPT = """ +You are an intelligent assistant that helps a human to analyze the information in a text document.
+Given a specific type of task and sample text, help the user by generating a 3 to 4 sentence description of an expert who could help solve the problem. +Use a format similar to the following: +You are an expert {{role}}. You are skilled at {{relevant skills}}. You are adept at helping people with {{specific task}}. + +task: {sample_task} +persona description:""" diff --git a/func-app/graphrag/prompt_tune/template/__init__.py b/func-app/graphrag/prompt_tune/template/__init__.py new file mode 100644 index 0000000000..e056762ff7 --- /dev/null +++ b/func-app/graphrag/prompt_tune/template/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Fine-tuning prompts for entity extraction, entity summarization, and community report summarization.""" + +from .community_report_summarization import COMMUNITY_REPORT_SUMMARIZATION_PROMPT +from .entity_extraction import ( + EXAMPLE_EXTRACTION_TEMPLATE, + GRAPH_EXTRACTION_JSON_PROMPT, + GRAPH_EXTRACTION_PROMPT, + UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE, + UNTYPED_GRAPH_EXTRACTION_PROMPT, +) +from .entity_summarization import ENTITY_SUMMARIZATION_PROMPT + +__all__ = [ + "COMMUNITY_REPORT_SUMMARIZATION_PROMPT", + "ENTITY_SUMMARIZATION_PROMPT", + "EXAMPLE_EXTRACTION_TEMPLATE", + "GRAPH_EXTRACTION_JSON_PROMPT", + "GRAPH_EXTRACTION_PROMPT", + "UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE", + "UNTYPED_GRAPH_EXTRACTION_PROMPT", +] diff --git a/func-app/graphrag/prompt_tune/template/community_report_summarization.py b/func-app/graphrag/prompt_tune/template/community_report_summarization.py new file mode 100644 index 0000000000..14e039ba41 --- /dev/null +++ b/func-app/graphrag/prompt_tune/template/community_report_summarization.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Fine-tuning prompts for community report summarization.""" + +COMMUNITY_REPORT_SUMMARIZATION_PROMPT = """ +{persona} + +# Goal +Write a comprehensive assessment report of a community taking on the role of a {role}. The content of this report includes an overview of the community's key entities and relationships. + +# Report Structure +The report should include the following sections: +- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. +- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant points associated with its entities. +- REPORT RATING: {report_rating_description} +- RATING EXPLANATION: Give a single sentence explanation of the rating. +- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. + +Return output as a well-formed JSON-formatted string with the following format. Don't use any unnecessary escape sequences. The output should be a single JSON object that can be parsed by json.loads. + {{ + "title": "<report_title>", + "summary": "<executive_summary>", + "rating": <threat_severity_rating>, + "rating_explanation": "<rating_explanation>", + "findings": "[{{"summary": "<insight_1_summary>", "explanation": "<insight_1_explanation>"}}, {{"summary": "<insight_2_summary>", "explanation": "<insight_2_explanation>"}}]" + }} + +# Grounding Rules +After each paragraph, add data record reference if the content of the paragraph was derived from one or more data records. Reference is in the format of [records: <record_source> (<record_id_list>), ...<record_source> (<record_id_list>)]. If there are more than 10 data records, show the top 10 most relevant records. +Each paragraph should contain multiple sentences of explanation and concrete examples with specific named entities. All paragraphs must have these references at the start and end.
Use "NONE" if there are no related roles or records. Everything should be in {language}. + +Example paragraph with references added: +This is a paragraph of the output text [records: Entities (1, 2, 3), Claims (2, 5), Relationships (10, 12)] + +# Example Input +----------- +Text: + +Entities + +id,entity,description +5,ABILA CITY PARK,Abila City Park is the location of the POK rally + +Relationships + +id,source,target,description +37,ABILA CITY PARK,POK RALLY,Abila City Park is the location of the POK rally +38,ABILA CITY PARK,POK,POK is holding a rally in Abila City Park +39,ABILA CITY PARK,POKRALLY,The POKRally is taking place at Abila City Park +40,ABILA CITY PARK,CENTRAL BULLETIN,Central Bulletin is reporting on the POK rally taking place in Abila City Park + +Output: +{{ + "title": "Abila City Park and POK Rally", + "summary": "The community revolves around the Abila City Park, which is the location of the POK rally. The park has relationships with POK, POKRALLY, and Central Bulletin, all +of which are associated with the rally event.", + "rating": 5.0, + "rating_explanation": "The impact rating is moderate due to the potential for unrest or conflict during the POK rally.", + "findings": [ + {{ + "summary": "Abila City Park as the central location", + "explanation": "Abila City Park is the central entity in this community, serving as the location for the POK rally. This park is the common link between all other +entities, suggesting its significance in the community. The park's association with the rally could potentially lead to issues such as public disorder or conflict, depending on the +nature of the rally and the reactions it provokes. [records: Entities (5), Relationships (37, 38, 39, 40)]" + }}, + {{ + "summary": "POK's role in the community", + "explanation": "POK is another key entity in this community, being the organizer of the rally at Abila City Park. The nature of POK and its rally could be a potential +source of threat, depending on their objectives and the reactions they provoke. The relationship between POK and the park is crucial in understanding the dynamics of this community. +[records: Relationships (38)]" + }}, + {{ + "summary": "POKRALLY as a significant event", + "explanation": "The POKRALLY is a significant event taking place at Abila City Park. This event is a key factor in the community's dynamics and could be a potential +source of threat, depending on the nature of the rally and the reactions it provokes. The relationship between the rally and the park is crucial in understanding the dynamics of this +community. [records: Relationships (39)]" + }}, + {{ + "summary": "Role of Central Bulletin", + "explanation": "Central Bulletin is reporting on the POK rally taking place in Abila City Park. This suggests that the event has attracted media attention, which could +amplify its impact on the community. The role of Central Bulletin could be significant in shaping public perception of the event and the entities involved. [records: Relationships +(40)]" + }} + ] + +}} + +# Real Data + +Use the following text for your answer. Do not make anything up in your answer. + +Text: +{{input_text}} +Output:""" diff --git a/func-app/graphrag/prompt_tune/template/entity_extraction.py b/func-app/graphrag/prompt_tune/template/entity_extraction.py new file mode 100644 index 0000000000..32d8756ec2 --- /dev/null +++ b/func-app/graphrag/prompt_tune/template/entity_extraction.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Fine-tuning prompts for entity extraction.""" + +GRAPH_EXTRACTION_PROMPT = """ +-Goal- +Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. + +-Steps- +1. Identify all entities. For each identified entity, extract the following information: +- entity_name: Name of the entity, capitalized +- entity_type: One of the following types: [{entity_types}] +- entity_description: Comprehensive description of the entity's attributes and activities +Format each entity as ("entity"{{tuple_delimiter}}<entity_name>{{tuple_delimiter}}<entity_type>{{tuple_delimiter}}<entity_description>) + +2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. +For each pair of related entities, extract the following information: +- source_entity: name of the source entity, as identified in step 1 +- target_entity: name of the target entity, as identified in step 1 +- relationship_description: explanation as to why you think the source entity and the target entity are related to each other +- relationship_strength: an integer score between 1 and 10, indicating strength of the relationship between the source entity and target entity +Format each relationship as ("relationship"{{tuple_delimiter}}<source_entity>{{tuple_delimiter}}<target_entity>{{tuple_delimiter}}<relationship_description>{{tuple_delimiter}}<relationship_strength>) + +3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{{record_delimiter}}** as the list delimiter. + +4. If you have to translate into {language}, just translate the descriptions, nothing else! + +5. When finished, output {{completion_delimiter}}. + +-Examples- +###################### +{examples} + +-Real Data- +###################### +entity_types: [{entity_types}] +text: {{input_text}} +###################### +output:""" + +GRAPH_EXTRACTION_JSON_PROMPT = """ +-Goal- +Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. + +-Steps- +1. Identify all entities. For each identified entity, extract the following information: +- entity_name: Name of the entity, capitalized +- entity_type: One of the following types: [{entity_types}] +- entity_description: Comprehensive description of the entity's attributes and activities +Format each entity output as a JSON entry with the following format: + +{{"name": <entity_name>, "type": <entity_type>, "description": <entity_description>}} + +2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. +For each pair of related entities, extract the following information: +- source_entity: name of the source entity, as identified in step 1 +- target_entity: name of the target entity, as identified in step 1 +- relationship_description: explanation as to why you think the source entity and the target entity are related to each other +- relationship_strength: an integer score between 1 and 10, indicating strength of the relationship between the source entity and target entity +Format each relationship as a JSON entry with the following format: + +{{"source": <source_entity>, "target": <target_entity>, "relationship": <relationship_description>, "relationship_strength": <relationship_strength>}} + +3. Return output in {language} as a single list of all JSON entities and relationships identified in steps 1 and 2. + +4.
If you have to translate into {language}, just translate the descriptions, nothing else! + +-Examples- +###################### +{examples} + +-Real Data- +###################### +entity_types: {entity_types} +text: {{input_text}} +###################### +output:""" + +EXAMPLE_EXTRACTION_TEMPLATE = """ +Example {n}: + +entity_types: [{entity_types}] +text: +{input_text} +------------------------ +output: +{output} +############################# + +""" + +UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE = """ +Example {n}: + +text: +{input_text} +------------------------ +output: +{output} +############################# + +""" + + +UNTYPED_GRAPH_EXTRACTION_PROMPT = """ +-Goal- +Given a text document that is potentially relevant to this activity, first identify all entities needed from the text in order to capture the information and ideas in the text. +Next, report all relationships among the identified entities. + +-Steps- +1. Identify all entities. For each identified entity, extract the following information: +- entity_name: Name of the entity, capitalized +- entity_type: Suggest several labels or categories for the entity. The categories should not be specific, but should be as general as possible. +- entity_description: Comprehensive description of the entity's attributes and activities +Format each entity as ("entity"{{tuple_delimiter}}<entity_name>{{tuple_delimiter}}<entity_type>{{tuple_delimiter}}<entity_description>) + +2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. +For each pair of related entities, extract the following information: +- source_entity: name of the source entity, as identified in step 1 +- target_entity: name of the target entity, as identified in step 1 +- relationship_description: explanation as to why you think the source entity and the target entity are related to each other +- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity +Format each relationship as ("relationship"{{tuple_delimiter}}<source_entity>{{tuple_delimiter}}<target_entity>{{tuple_delimiter}}<relationship_description>{{tuple_delimiter}}<relationship_strength>) + +3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{{record_delimiter}}** as the list delimiter. + +4. If you have to translate into {language}, just translate the descriptions, nothing else! + +5. When finished, output {{completion_delimiter}}. + +-Examples- +###################### +{examples} + +-Real Data- +###################### +text: {{input_text}} +###################### +output: +""" diff --git a/func-app/graphrag/prompt_tune/template/entity_summarization.py b/func-app/graphrag/prompt_tune/template/entity_summarization.py new file mode 100644 index 0000000000..60294a291b --- /dev/null +++ b/func-app/graphrag/prompt_tune/template/entity_summarization.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Fine-tuning prompts for entity summarization.""" + +ENTITY_SUMMARIZATION_PROMPT = """ +{persona} +Using your expertise, you're asked to generate a comprehensive summary of the data provided below. +Given one or two entities and a list of descriptions, all related to the same entity or group of entities, please concatenate all of these into a single, concise description in {language}. Make sure to include information collected from all the descriptions. +If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
+Make sure it is written in third person, and include the entity names so that we have the full context. + +Enrich it as much as you can with relevant information from the nearby text; this is very important. + +If no answer is possible, or the description is empty, only convey information that is provided within the text. +####### +-Data- +Entities: {{entity_name}} +Description List: {{description_list}} +####### +Output:""" diff --git a/func-app/graphrag/query/__init__.py b/func-app/graphrag/query/__init__.py new file mode 100644 index 0000000000..58a557f8a2 --- /dev/null +++ b/func-app/graphrag/query/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""GraphRAG Orchestration Module.""" diff --git a/func-app/graphrag/query/__main__.py b/func-app/graphrag/query/__main__.py new file mode 100644 index 0000000000..e2e01d6fa6 --- /dev/null +++ b/func-app/graphrag/query/__main__.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Query Engine package root.""" + +import argparse +from enum import Enum + +from .cli import run_global_search, run_local_search + +INVALID_METHOD_ERROR = "Invalid method" + + +class SearchType(Enum): + """The type of search to run.""" + + LOCAL = "local" + GLOBAL = "global" + + def __str__(self): + """Return the string representation of the enum value.""" + return self.value + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--config", + help="The configuration yaml file to use when running the query", + required=False, + type=str, + ) + + parser.add_argument( + "--data", + help="The path with the output data from the pipeline", + required=False, + type=str, + ) + + parser.add_argument( + "--root", + help="The data project root. Default value: the current directory", + required=False, + default=".", + type=str, + ) + + parser.add_argument( + "--method", + help="The method to run, one of: local or global", + required=True, + type=SearchType, + choices=list(SearchType), + ) + + parser.add_argument( + "--community_level", + help="Community level in the Leiden community hierarchy from which we will load the community reports. Higher value means we use reports on smaller communities", + type=int, + default=2, + ) + + parser.add_argument( + "--response_type", + help="Free form text describing the response type and format, can be anything, e.g.
diff --git a/func-app/graphrag/query/__init__.py b/func-app/graphrag/query/__init__.py
new file mode 100644
index 0000000000..58a557f8a2
--- /dev/null
+++ b/func-app/graphrag/query/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""GraphRAG Orchestration Module."""
diff --git a/func-app/graphrag/query/__main__.py b/func-app/graphrag/query/__main__.py
new file mode 100644
index 0000000000..e2e01d6fa6
--- /dev/null
+++ b/func-app/graphrag/query/__main__.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Query Engine package root."""
+
+import argparse
+from enum import Enum
+
+from .cli import run_global_search, run_local_search
+
+INVALID_METHOD_ERROR = "Invalid method"
+
+
+class SearchType(Enum):
+    """The type of search to run."""
+
+    LOCAL = "local"
+    GLOBAL = "global"
+
+    def __str__(self):
+        """Return the string representation of the enum value."""
+        return self.value
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--config",
+        help="The configuration yaml file to use when running the query",
+        required=False,
+        type=str,
+    )
+
+    parser.add_argument(
+        "--data",
+        help="The path with the output data from the pipeline",
+        required=False,
+        type=str,
+    )
+
+    parser.add_argument(
+        "--root",
+        help="The data project root. Default value: the current directory",
+        required=False,
+        default=".",
+        type=str,
+    )
+
+    parser.add_argument(
+        "--method",
+        help="The method to run, one of: local or global",
+        required=True,
+        type=SearchType,
+        choices=list(SearchType),
+    )
+
+    parser.add_argument(
+        "--community_level",
+        help="Community level in the Leiden community hierarchy from which to load the community reports. A higher value means reports on smaller communities are used.",
+        type=int,
+        default=2,
+    )
+
+    parser.add_argument(
+        "--response_type",
+        help="Free-form text describing the response type and format, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report",
+        type=str,
+        default="Multiple Paragraphs",
+    )
+
+    parser.add_argument(
+        "--context_id",
+        help="GUID describing the context in which the search should be performed",
+        type=str,
+        #default="00000000-0000-0000-0000-000000000000",
+    )
+
+    parser.add_argument(
+        "--optimized_search",
+        help="Runs optimized search and exports artifacts",
+        action="store_true",
+        default=False,
+    )
+
+    parser.add_argument(
+        "--use_kusto_community_reports",
+        help="If enabled, community reports stored in Kusto are used during the query",
+        action="store_true",
+    )
+
+    parser.add_argument(
+        "--paths",
+        help="Query execution path to use (0 = normal GraphRAG search)",
+        type=int,
+        default=0, # Default to normal graphrag search
+    )
+
+    parser.add_argument(
+        "query",
+        nargs=1,
+        help="The query to run",
+        type=str,
+    )
+
+    args = parser.parse_args()
+
+    match args.method:
+        case SearchType.LOCAL:
+            run_local_search(
+                args.config,
+                args.data,
+                args.root,
+                args.community_level,
+                args.response_type,
+                args.context_id,
+                args.query[0],
+                optimized_search=args.optimized_search,
+                use_kusto_community_reports=args.use_kusto_community_reports,
+                paths=args.paths,
+            )
+        case SearchType.GLOBAL:
+            run_global_search(
+                args.config,
+                args.data,
+                args.root,
+                args.community_level,
+                args.response_type,
+                args.context_id,
+                args.query[0],
+            )
+        case _:
+            raise ValueError(INVALID_METHOD_ERROR)
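
The same search can be driven programmatically; a sketch that mirrors the parser defaults above (passing an empty context_id for the no-GUID case is an assumption, since the CLI leaves its default commented out):

# Programmatic equivalent of `python -m graphrag.query --method local ...`,
# using the same defaults the argument parser applies above.
from graphrag.query.cli import run_local_search

response = run_local_search(
    None,                    # --config
    None,                    # --data
    ".",                     # --root
    2,                       # --community_level
    "Multiple Paragraphs",   # --response_type
    "",                      # --context_id (assumed empty when no GUID is supplied)
    "Who is Tim Belden?",
    optimized_search=False,
    use_kusto_community_reports=False,
    paths=0,
)
print(response)
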
diff --git a/func-app/graphrag/query/cli.py b/func-app/graphrag/query/cli.py
new file mode 100644
index 0000000000..a8a06d69cd
--- /dev/null
+++ b/func-app/graphrag/query/cli.py
@@ -0,0 +1,472 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Command line interface for the query module."""
+
+import asyncio
+import os
+from pathlib import Path
+from typing import cast
+from io import BytesIO
+
+from datashaper import VerbCallbacks
+from graphrag.common.progress.rich import RichProgressReporter
+from graphrag.common.storage import PipelineStorage, BlobPipelineStorage, FilePipelineStorage
+from graphrag.common.utils.context_utils import get_files_by_contextid
+from graphrag.config.enums import StorageType
+from azure.core.exceptions import ResourceNotFoundError
+
+import pandas as pd
+
+from graphrag.config import (
+    create_graphrag_config,
+    GraphRagConfig,
+)
+from graphrag.common.progress import PrintProgressReporter
+from graphrag.index.verbs.entities.extraction.strategies.graph_intelligence.run_graph_intelligence import run_gi
+from graphrag.index.verbs.entities.extraction.strategies.typing import Document
+from graphrag.model.entity import Entity
+from graphrag.query.input.loaders.dfs import (
+    store_entity_semantic_embeddings,
+)
+from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
+from graphrag.vector_stores.base import BaseVectorStore
+from graphrag.vector_stores.lancedb import LanceDBVectorStore
+from graphrag.vector_stores.kusto import KustoVectorStore
+from .factories import get_global_search_engine, get_local_search_engine
+from .indexer_adapters import (
+    read_indexer_covariates,
+    read_indexer_entities,
+    read_indexer_relationships,
+    read_indexer_reports,
+    kt_read_indexer_reports,
+    read_indexer_text_units,
+)
+
+from common.graph_db_client import GraphDBClient
+
+reporter = PrintProgressReporter("")
+
+def __get_embedding_description_store(
+    entities: list[Entity] = [],
+    vector_store_type: str = VectorStoreType.LanceDB,
+    config_args: dict | None = None,
+    context_id: str = "",
+):
+    """Get the embedding description store."""
+    if not config_args:
+        config_args = {}
+
+    collection_name = config_args.get(
+        "query_collection_name", "entity_description_embeddings"
+    )
+    config_args.update({"collection_name": f"{collection_name}_{context_id}" if context_id else collection_name})
+    vector_name = config_args.get(
+        "vector_search_column", "description_embedding"
+    )
+    config_args.update({"vector_name": vector_name})
+    config_args.update({"reports_name": f"reports_{context_id}" if context_id else "reports"})
+    config_args.update({"text_units_name": f"text_units_{context_id}"})
+
+    description_embedding_store = VectorStoreFactory.get_vector_store(
+        vector_store_type=vector_store_type, kwargs=config_args
+    )
+
+    description_embedding_store.connect(**config_args)
+
+    if vector_store_type == VectorStoreType.Kusto:
+        return description_embedding_store
+
+    if config_args.get("overwrite", True):
+        # this step assumes the embeddings were originally stored in a file rather
+        # than a vector database
+
+        # dump embeddings from the entities list to the description_embedding_store
+        store_entity_semantic_embeddings(
+            entities=entities, vectorstore=description_embedding_store
+        )
+    else:
+        # load description embeddings to an in-memory lancedb vectorstore
+        # to connect to a remote db, specify url and port values.
+        description_embedding_store = LanceDBVectorStore(
+            collection_name=collection_name
+        )
+        description_embedding_store.connect(
+            db_uri=config_args.get("db_uri", "./lancedb")
+        )
+
+        # load data from an existing table
+        description_embedding_store.document_collection = (
+            description_embedding_store.db_connection.open_table(
+                description_embedding_store.collection_name
+            )
+        )
+
+    return description_embedding_store
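
Inside this module the helper can be exercised directly; a usage sketch assuming a local LanceDB database and an already-constructed embedder (both illustrative, not part of the patch):

# Illustrative only: `entities` and `embedder` are assumed to exist in scope.
description_embedding_store = __get_embedding_description_store(
    entities=entities,
    vector_store_type=VectorStoreType.LanceDB,
    config_args={"db_uri": "./lancedb", "overwrite": True},
    context_id="",  # empty -> unscoped collection names
)
results = description_embedding_store.similarity_search_by_text(
    text="power traders",
    text_embedder=lambda t: embedder.embed(t),
    k=10,
)
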
+
+
+def run_global_search(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+):
+    """Run a global search with the given query."""
+    data_dir, root_dir, config = _configure_paths_and_settings(
+        data_dir, root_dir, config_dir
+    )
+    if config.graphdb.enabled:
+        graph_db_client = GraphDBClient(config.graphdb)
+    data_path = Path(data_dir)
+
+    final_nodes: pd.DataFrame = pd.read_parquet(
+        data_path / "create_final_nodes.parquet"
+    )
+    if config.graphdb.enabled:
+        final_entities = graph_db_client.query_vertices()
+    else:
+        final_entities: pd.DataFrame = pd.read_parquet(
+            data_path / "create_final_entities.parquet"
+        )
+    final_community_reports: pd.DataFrame = pd.read_parquet(
+        data_path / "create_final_community_reports.parquet"
+    )
+
+    reports = read_indexer_reports(
+        final_community_reports, final_nodes, community_level
+    )
+    entities = read_indexer_entities(final_nodes, final_entities, community_level)
+    search_engine = get_global_search_engine(
+        config,
+        reports=reports,
+        entities=entities,
+        response_type=response_type,
+    )
+
+    result = search_engine.search(query=query)
+
+    reporter.success(f"Global Search Response: {result.response}")
+    return result.response
+
+def path0(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+    optimized_search: bool = False,
+    use_kusto_community_reports: bool = False,
+    ):
+    """Run a local search with the given query."""
+    data_dir, root_dir, config = _configure_paths_and_settings(
+        data_dir, root_dir, config_dir
+    )
+
+    vector_store_args = (
+        config.embeddings.vector_store if config.embeddings.vector_store else {}
+    )
+
+    reporter.info(f"Vector Store Args: {vector_store_args}")
+    vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
+
+    entities=[]
+    text_units=[]
+    covariates=[]
+    reports=[]
+    final_relationships=[]
+
+    if config.storage.type == StorageType.blob:
+        if config.storage.container_name is not None:
+            output_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string,
+                                                    container_name=config.storage.container_name,
+                                                    storage_account_blob_url=config.storage.storage_account_blob_url)
+        else:
+            msg = "Storage type is Blob but container name is invalid"
+            raise ValueError(msg)
+    elif config.storage.type == StorageType.file:
+        output_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir)
+
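
The output client above and the input client in the legacy block below repeat the same type checks; a hedged refactor sketch of a shared helper, using the same names and config shape as this module:

# Sketch only: a shared validator/factory for the storage-client selections in path0.
def _storage_client_for(config) -> PipelineStorage:
    if config.storage.type == StorageType.memory:
        msg = "Memory storage is not supported"
        raise ValueError(msg)
    if config.storage.type == StorageType.blob:
        if config.storage.container_name is None:
            msg = "Storage type is Blob but container name is invalid"
            raise ValueError(msg)
        return BlobPipelineStorage(
            connection_string=config.storage.connection_string,
            container_name=config.storage.container_name,
            storage_account_blob_url=config.storage.storage_account_blob_url,
        )
    return FilePipelineStorage(config.root_dir)
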
+
+    ##### LEGACY #######################
+
+    if vector_store_type == VectorStoreType.LanceDB:
+        # for the POC, the input artifacts blob, output artifacts blob and input query blob storage are the same.
+        if config.storage.type == StorageType.memory:
+            msg = "Memory storage is not supported"
+            raise ValueError(msg)
+        if config.storage.type == StorageType.blob:
+            if config.storage.container_name is not None:
+                input_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string,
+                                                        container_name=config.storage.container_name,
+                                                        storage_account_blob_url=config.storage.storage_account_blob_url)
+            else:
+                msg = "Storage type is Blob but container name is invalid"
+                raise ValueError(msg)
+        if config.storage.type == StorageType.file:
+            input_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir)
+
+        data_paths = get_files_by_contextid(config, context_id)
+        final_nodes = pd.DataFrame()
+        final_community_reports = pd.DataFrame()
+        final_text_units = pd.DataFrame()
+        final_relationships = pd.DataFrame()
+        final_entities = pd.DataFrame()
+        final_covariates = pd.DataFrame()
+
+        for data_path in data_paths:
+            # check the config for the output storage type and then read the data from that storage.
+
+            # GraphDB: we may need to change the line below to read nodes data from Graph DB
+            final_nodes = pd.concat([final_nodes, read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")])
+            final_community_reports = pd.concat([final_community_reports, read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) # KustoDB: final_entities, final_nodes and final_reports should be merged and inserted into Kusto
+            final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) # LanceDB search needs it for embedding mapping; we have embeddings in entities and should use them from there. KustoDB must already have this sorted.
+            final_relationships = pd.concat([final_relationships, read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")])
+
+            if not optimized_search:
+                final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")])
+
+            final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")])
+
+        ############# End of for loop
+
+        entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read final nodes and entities data and merge them.
+        reports = read_indexer_reports(
+            final_community_reports, final_nodes, community_level
+        )
+
+        final_relationships = read_indexer_relationships(final_relationships)
+
+        covariates = (
+            read_indexer_covariates(final_covariates)
+            if not final_covariates.empty
+            else []
+        )
+        text_units = read_indexer_text_units(final_text_units)
+
+    ########################################################################################
+
+    if use_kusto_community_reports:
+        raise ValueError("Using community reports is not supported.")
+
+    description_embedding_store = __get_embedding_description_store(
+        entities=entities,
+        vector_store_type=vector_store_type,
+        config_args=vector_store_args,
+        context_id=context_id,
+    )
+
+    # If KUSTO is enabled, both entities and final_relationships must be empty.
+    search_engine = get_local_search_engine(
+        config,
+        reports=reports,
+        text_units=text_units,
+        entities=entities,
+        relationships=final_relationships,
+        covariates={"claims": covariates},
+        description_embedding_store=description_embedding_store,
+        response_type=response_type,
+        context_id=context_id,
+        is_optimized_search=optimized_search,
+        use_kusto_community_reports=use_kusto_community_reports,
+        graphdb_config=config.graphdb,
+    )
+
+    if optimized_search:
+        result = search_engine.optimized_search(query=query)
+    else:
+        result = search_engine.search(query=query)
+    for key in result.context_data.keys():
+        # the editor may flag this call as an error, but it is valid at runtime.
+        asyncio.run(output_storage_client.set("query/output/" + key + ".parquet", result.context_data[key].to_parquet()))
+    reporter.success(f"Local Search Response: {result.response}")
+    return result.response
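
The persistence loop above hands each context table to PipelineStorage.set as raw parquet bytes, and read_paraquet_file (defined below) reverses the operation. The round trip in miniature, using only pandas and pyarrow:

# DataFrame -> parquet bytes -> DataFrame, as done by the write/read pair above.
import pandas as pd
from io import BytesIO

df = pd.DataFrame({"id": [1, 2], "title": ["a", "b"]})
payload = df.to_parquet()  # returns bytes when no path is given
restored = pd.read_parquet(BytesIO(payload), engine="pyarrow")
assert restored.equals(df)
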
+
+def path1(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+    optimized_search: bool = False,
+    use_kusto_community_reports: bool = False,
+    ):
+    raise NotImplementedError("path1 is not implemented")
+
+def path2(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+    optimized_search: bool = False,
+    use_kusto_community_reports: bool = False,
+    ):
+    """Path 2
+    Find all the emails sent to a trader by Tim Belden
+    a. Query -> LLM -> Entity Extracted -> 5 entities -> Set A [TimBelden1]
+    b. Query -> LLM -> Embeddings -> Y [x1..... Xn]
+    c. Run the query on Kusto for embedding Y [x1.....xn] for entityId in [TimBelden1]
+    d. Get the text units and get the response"""
+    data_dir, root_dir, config = _configure_paths_and_settings(
+        data_dir, root_dir, config_dir
+    )
+
+    # Populate args with a dict of arguments for the LLM
+    args = {}
+    args['api_key'] = config.llm.api_key
+    args['type'] = config.llm.type
+    args['model'] = config.llm.model
+    args['model_supports_json'] = config.llm.model_supports_json
+    args['api_base'] = config.llm.api_base
+    args['api_version'] = config.llm.api_version
+    args['deployment_name'] = config.llm.deployment_name
+    llmm = {}
+    llmm['llm'] = args
+
+    result = asyncio.run(run_gi(
+        docs=[Document(text=query, id='0')],
+        entity_types=config.entity_extraction.entity_types,
+        reporter=None,
+        pipeline_cache=None,
+        args=llmm,
+    ))
+
+    print(result.entities)
+    exit(0)
+
+def path3(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+    optimized_search: bool = False,
+    use_kusto_community_reports: bool = False,
+    ):
+    raise NotImplementedError("path3 is not implemented")
+
+
+def run_local_search(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+    optimized_search: bool = False,
+    use_kusto_community_reports: bool = False,
+    paths: int = 0,):
+    """Run a local search with the given query."""
+    if paths == 1:
+        return path1(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports)
+    elif paths == 2:
+        return path2(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports)
+    elif paths == 3:
+        return path3(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports)
+    return path0(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports)
+
+def blob_exists(container_client, blob_name):
+    """Check whether the given blob exists in the container."""
+    blob_client = container_client.get_blob_client(blob_name)
+    try:
+        # Attempt to get the blob properties
+        blob_client.get_blob_properties()
+        return True
+    except ResourceNotFoundError:
+        # Blob does not exist
+        return False
+
+
+def read_paraquet_file(storage: PipelineStorage, path: str):
+    """Read a parquet file from storage into a DataFrame (empty if missing)."""
+    # TODO: create a separate enum for parquet storage types
+    file_data = asyncio.run(storage.get(path, True))
+    if file_data is None:
+        return pd.DataFrame()
+    return pd.read_parquet(BytesIO(file_data), engine="pyarrow")
+
+def _configure_paths_and_settings(
+    data_dir: str | None,
+    root_dir: str | None,
+    config_dir: str | None,
+) -> tuple[str, str | None, GraphRagConfig]:
+    if data_dir is None and root_dir is None:
+        msg = "Either data_dir or root_dir must be provided."
+ raise ValueError(msg) + if data_dir is None: + data_dir = _infer_data_dir(cast(str, root_dir)) + config = _create_graphrag_config(root_dir, config_dir) + return data_dir, root_dir, config + + +def _infer_data_dir(root: str) -> str: + output = Path(root) / "output" + # use the latest data-run folder + if output.exists(): + folders = sorted(output.iterdir(), key=os.path.getmtime, reverse=True) + if len(folders) > 0: + folder = folders[0] + return str((folder / "artifacts").absolute()) + msg = f"Could not infer data directory from root={root}" + raise ValueError(msg) + + +def _create_graphrag_config( + root: str | None, + config_dir: str | None, +) -> GraphRagConfig: + """Create a GraphRag configuration.""" + return _read_config_parameters(root or "./", config_dir) + + +def _read_config_parameters(root: str, config: str | None): + _root = Path(root) + settings_yaml = ( + Path(config) + if config and Path(config).suffix in [".yaml", ".yml"] + else _root / "settings.yaml" + ) + if not settings_yaml.exists(): + settings_yaml = _root / "settings.yml" + + if settings_yaml.exists(): + reporter.info(f"Reading settings from {settings_yaml}") + with settings_yaml.open( + "rb", + ) as file: + import yaml + + data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + settings_json = ( + Path(config) + if config and Path(config).suffix == ".json" + else _root / "settings.json" + ) + if settings_json.exists(): + reporter.info(f"Reading settings from {settings_json}") + with settings_json.open("rb") as file: + import json + + data = json.loads(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + reporter.info("Reading settings from environment variables") + return create_graphrag_config(root_dir=root) diff --git a/func-app/graphrag/query/context_builder/__init__.py b/func-app/graphrag/query/context_builder/__init__.py new file mode 100644 index 0000000000..7e27364e1e --- /dev/null +++ b/func-app/graphrag/query/context_builder/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Functions to build context for system prompt to generate responses for a user query.""" diff --git a/func-app/graphrag/query/context_builder/builders.py b/func-app/graphrag/query/context_builder/builders.py new file mode 100644 index 0000000000..7a4ba277ae --- /dev/null +++ b/func-app/graphrag/query/context_builder/builders.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Base classes for global and local context builders.""" + +from abc import ABC, abstractmethod + +import pandas as pd + +from graphrag.query.context_builder.conversation_history import ( + ConversationHistory, +) + + +class GlobalContextBuilder(ABC): + """Base class for global-search context builders.""" + + @abstractmethod + def build_context( + self, conversation_history: ConversationHistory | None = None, **kwargs + ) -> tuple[str | list[str], dict[str, pd.DataFrame]]: + """Build the context for the global search mode.""" + + +class LocalContextBuilder(ABC): + """Base class for local-search context builders.""" + + @abstractmethod + def build_context( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> tuple[str | list[str], dict[str, pd.DataFrame]]: + """Build the context for the local search mode.""" diff --git a/func-app/graphrag/query/context_builder/community_context.py b/func-app/graphrag/query/context_builder/community_context.py new file mode 100644 index 0000000000..ad345a2704 --- /dev/null +++ b/func-app/graphrag/query/context_builder/community_context.py @@ -0,0 +1,253 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Community Context.""" + +import logging +import random +from typing import Any, cast + +import pandas as pd +import tiktoken + +from graphrag.model import CommunityReport, Entity +from graphrag.query.llm.text_utils import num_tokens + +log = logging.getLogger(__name__) + + +def build_community_context( + community_reports: list[CommunityReport], + entities: list[Entity] | None = None, + token_encoder: tiktoken.Encoding | None = None, + use_community_summary: bool = True, + column_delimiter: str = "|", + shuffle_data: bool = True, + include_community_rank: bool = False, + min_community_rank: int = 0, + community_rank_name: str = "rank", + include_community_weight: bool = True, + community_weight_name: str = "occurrence weight", + normalize_community_weight: bool = True, + max_tokens: int = 8000, + single_batch: bool = True, + context_name: str = "Reports", + random_state: int = 86, + is_optimized_search: bool = False, +) -> tuple[str | list[str], dict[str, pd.DataFrame]]: + """ + Prepare community report data table as context data for system prompt. + + If entities are provided, the community weight is calculated as the count of text units associated with entities within the community. + + The calculated weight is added as an attribute to the community reports and added to the context data table. 
+ """ + + def _is_included(report: CommunityReport) -> bool: + return report.rank is not None and report.rank >= min_community_rank + + def _get_header(attributes: list[str]) -> list[str]: + header = ["id", "title"] + attributes = [col for col in attributes if col not in header] + if not include_community_weight: + attributes = [col for col in attributes if col != community_weight_name] + header.extend(attributes) + header.append("summary" if use_community_summary else "content") + if include_community_rank: + header.append(community_rank_name) + return header + + def _report_context_text( + report: CommunityReport, attributes: list[str], + is_optimized_search: bool = False + ) -> tuple[str, list[str]]: + context: list[str] = [ + report.short_id if report.short_id else "", + report.title, + *[ + str(report.attributes.get(field, "")) if report.attributes else "" + for field in attributes + ], + ] + context.append(report.summary if use_community_summary else report.full_content) + if include_community_rank: + context.append(str(report.rank)) + result = column_delimiter.join(context) + "\n" + return result, context + + compute_community_weights = ( + entities + and len(community_reports) > 0 + and include_community_weight + and ( + community_reports[0].attributes is None + or community_weight_name not in community_reports[0].attributes + ) + ) + if compute_community_weights: + log.info("Computing community weights...") + community_reports = _compute_community_weights( + community_reports=community_reports, + entities=entities, + weight_attribute=community_weight_name, + normalize=normalize_community_weight, + ) + + selected_reports = [report for report in community_reports if _is_included(report)] + + if selected_reports is None or len(selected_reports) == 0: + return ([], {}) + + if shuffle_data: + random.seed(random_state) + random.shuffle(selected_reports) + + # "global" variables + attributes = ( + list(community_reports[0].attributes.keys()) + if community_reports[0].attributes + else [] + ) + header = _get_header(attributes) + all_context_text: list[str] = [] + all_context_records: list[pd.DataFrame] = [] + + # batch variables + batch_text: str = "" + batch_tokens: int = 0 + batch_records: list[list[str]] = [] + + def _init_batch() -> None: + nonlocal batch_text, batch_tokens, batch_records + batch_text = ( + f"-----{context_name}-----" + "\n" + column_delimiter.join(header) + "\n" + ) + batch_tokens = num_tokens(batch_text, token_encoder) + batch_records = [] + + def _cut_batch() -> None: + # convert the current context records to pandas dataframe and sort by weight and rank if exist + record_df = _convert_report_context_to_df( + context_records=batch_records, + header=header, + weight_column=community_weight_name + if entities and include_community_weight + else None, + rank_column=community_rank_name if include_community_rank else None, + ) + if len(record_df) == 0: + return + current_context_text = record_df.to_csv(index=False, sep=column_delimiter) + all_context_text.append(current_context_text) + all_context_records.append(record_df) + + # initialize the first batch + _init_batch() + + for report in selected_reports: + new_context_text, new_context = _report_context_text(report, attributes, is_optimized_search) + new_tokens = num_tokens(new_context_text, token_encoder) + + if batch_tokens + new_tokens > max_tokens: + # add the current batch to the context data and start a new batch if we are in multi-batch mode + _cut_batch() + if single_batch: + break + _init_batch() + + # add 
current report to the current batch + batch_text += new_context_text + batch_tokens += new_tokens + batch_records.append(new_context) + + # add the last batch if it has not been added + if batch_text not in all_context_text: + _cut_batch() + + if len(all_context_records) == 0: + log.warning( + "Warning: No community records added when building community context." + ) + return ([], {}) + + return all_context_text, { + context_name.lower(): pd.concat(all_context_records, ignore_index=True) + } + + +def _compute_community_weights( + community_reports: list[CommunityReport], + entities: list[Entity] | None, + weight_attribute: str = "occurrence", + normalize: bool = True, +) -> list[CommunityReport]: + """Calculate a community's weight as count of text units associated with entities within the community.""" + if not entities: + return community_reports + + community_text_units = {} + for entity in entities: + if entity.community_ids: + for community_id in entity.community_ids: + if community_id not in community_text_units: + community_text_units[community_id] = [] + community_text_units[community_id].extend(entity.text_unit_ids) + for report in community_reports: + if not report.attributes: + report.attributes = {} + report.attributes[weight_attribute] = len( + set(community_text_units.get(report.community_id, [])) + ) + if normalize: + # normalize by max weight + all_weights = [ + report.attributes[weight_attribute] + for report in community_reports + if report.attributes + ] + max_weight = max(all_weights) + for report in community_reports: + if report.attributes: + report.attributes[weight_attribute] = ( + report.attributes[weight_attribute] / max_weight + ) + return community_reports + + +def _rank_report_context( + report_df: pd.DataFrame, + weight_column: str | None = "occurrence weight", + rank_column: str | None = "rank", +) -> pd.DataFrame: + """Sort report context by community weight and rank if exist.""" + rank_attributes: list[str] = [] + if weight_column: + rank_attributes.append(weight_column) + report_df[weight_column] = report_df[weight_column].astype(float) + if rank_column: + rank_attributes.append(rank_column) + report_df[rank_column] = report_df[rank_column].astype(float) + if len(rank_attributes) > 0: + report_df.sort_values(by=rank_attributes, ascending=False, inplace=True) + return report_df + + +def _convert_report_context_to_df( + context_records: list[list[str]], + header: list[str], + weight_column: str | None = None, + rank_column: str | None = None, +) -> pd.DataFrame: + """Convert report context records to pandas dataframe and sort by weight and rank if exist.""" + if len(context_records) == 0: + return pd.DataFrame() + + record_df = pd.DataFrame( + context_records, + columns=cast(Any, header), + ) + return _rank_report_context( + report_df=record_df, + weight_column=weight_column, + rank_column=rank_column, + ) diff --git a/func-app/graphrag/query/context_builder/conversation_history.py b/func-app/graphrag/query/context_builder/conversation_history.py new file mode 100644 index 0000000000..33f516dbd4 --- /dev/null +++ b/func-app/graphrag/query/context_builder/conversation_history.py @@ -0,0 +1,212 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Classes for storing and managing conversation history."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import pandas as pd
+import tiktoken
+
+from graphrag.query.llm.text_utils import num_tokens
+
+"""
+Enum for conversation roles
+"""
+
+
+class ConversationRole(str, Enum):
+    """Enum for conversation roles."""
+
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+
+    @staticmethod
+    def from_string(value: str) -> "ConversationRole":
+        """Convert string to ConversationRole."""
+        if value == "system":
+            return ConversationRole.SYSTEM
+        if value == "user":
+            return ConversationRole.USER
+        if value == "assistant":
+            return ConversationRole.ASSISTANT
+
+        msg = f"Invalid Role: {value}"
+        raise ValueError(msg)
+
+    def __str__(self) -> str:
+        """Return string representation of the enum value."""
+        return self.value
+
+
+"""
+Data class for storing a single conversation turn
+"""
+
+
+@dataclass
+class ConversationTurn:
+    """Data class for storing a single conversation turn."""
+
+    role: ConversationRole
+    content: str
+
+    def __str__(self) -> str:
+        """Return string representation of the conversation turn."""
+        return f"{self.role}: {self.content}"
+
+
+@dataclass
+class QATurn:
+    """
+    Data class for storing a QA turn.
+
+    A QA turn contains a user question and one or more assistant answers.
+    """
+
+    user_query: ConversationTurn
+    assistant_answers: list[ConversationTurn] | None = None
+
+    def get_answer_text(self) -> str | None:
+        """Get the text of the assistant answers."""
+        return (
+            "\n".join([answer.content for answer in self.assistant_answers])
+            if self.assistant_answers
+            else None
+        )
+
+    def __str__(self) -> str:
+        """Return string representation of the QA turn."""
+        answers = self.get_answer_text()
+        return (
+            f"Question: {self.user_query.content}\nAnswer: {answers}"
+            if answers
+            else f"Question: {self.user_query.content}"
+        )
+
+
+class ConversationHistory:
+    """Class for storing a conversation history."""
+
+    turns: list[ConversationTurn]
+
+    def __init__(self):
+        self.turns = []
+
+    @classmethod
+    def from_list(
+        cls, conversation_turns: list[dict[str, str]]
+    ) -> "ConversationHistory":
+        """
+        Create a conversation history from a list of conversation turns.
+ + Each turn is a dictionary in the form of {"role": "", "content": ""} + """ + history = cls() + for turn in conversation_turns: + history.turns.append( + ConversationTurn( + role=ConversationRole.from_string( + turn.get("role", ConversationRole.USER) + ), + content=turn.get("content", ""), + ) + ) + return history + + def add_turn(self, role: ConversationRole, content: str): + """Add a new turn to the conversation history.""" + self.turns.append(ConversationTurn(role=role, content=content)) + + def to_qa_turns(self) -> list[QATurn]: + """Convert conversation history to a list of QA turns.""" + qa_turns = list[QATurn]() + current_qa_turn = None + for turn in self.turns: + if turn.role == ConversationRole.USER: + if current_qa_turn: + qa_turns.append(current_qa_turn) + current_qa_turn = QATurn(user_query=turn, assistant_answers=[]) + else: + if current_qa_turn: + current_qa_turn.assistant_answers.append(turn) # type: ignore + if current_qa_turn: + qa_turns.append(current_qa_turn) + return qa_turns + + def get_user_turns(self, max_user_turns: int | None = 1) -> list[str]: + """Get the last user turns in the conversation history.""" + user_turns = [] + for turn in self.turns[::-1]: + if turn.role == ConversationRole.USER: + user_turns.append(turn.content) + if max_user_turns and len(user_turns) >= max_user_turns: + break + return user_turns + + def build_context( + self, + token_encoder: tiktoken.Encoding | None = None, + include_user_turns_only: bool = True, + max_qa_turns: int | None = 5, + max_tokens: int = 8000, + recency_bias: bool = True, + column_delimiter: str = "|", + context_name: str = "Conversation History", + ) -> tuple[str, dict[str, pd.DataFrame]]: + """ + Prepare conversation history as context data for system prompt. + + Parameters + ---------- + user_queries_only: If True, only user queries (not assistant responses) will be included in the context, default is True. + max_qa_turns: Maximum number of QA turns to include in the context, default is 1. + recency_bias: If True, reverse the order of the conversation history to ensure last QA got prioritized. + column_delimiter: Delimiter to use for separating columns in the context data, default is "|". + context_name: Name of the context, default is "Conversation History". 
+ + """ + qa_turns = self.to_qa_turns() + if include_user_turns_only: + qa_turns = [ + QATurn(user_query=qa_turn.user_query, assistant_answers=None) + for qa_turn in qa_turns + ] + if recency_bias: + qa_turns = qa_turns[::-1] + if max_qa_turns and len(qa_turns) > max_qa_turns: + qa_turns = qa_turns[:max_qa_turns] + + # build context for qa turns + # add context header + if len(qa_turns) == 0 or not qa_turns: + return ("", {context_name: pd.DataFrame()}) + + # add table header + header = f"-----{context_name}-----" + "\n" + + turn_list = [] + current_context_df = pd.DataFrame() + for turn in qa_turns: + turn_list.append({ + "turn": ConversationRole.USER.__str__(), + "content": turn.user_query.content, + }) + if turn.assistant_answers: + turn_list.append({ + "turn": ConversationRole.ASSISTANT.__str__(), + "content": turn.get_answer_text(), + }) + + context_df = pd.DataFrame(turn_list) + context_text = header + context_df.to_csv(sep=column_delimiter, index=False) + if num_tokens(context_text, token_encoder) > max_tokens: + break + + current_context_df = context_df + context_text = header + current_context_df.to_csv( + sep=column_delimiter, index=False + ) + return (context_text, {context_name.lower(): current_context_df}) diff --git a/func-app/graphrag/query/context_builder/entity_extraction.py b/func-app/graphrag/query/context_builder/entity_extraction.py new file mode 100644 index 0000000000..037da80932 --- /dev/null +++ b/func-app/graphrag/query/context_builder/entity_extraction.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Orchestration Context Builders.""" + +from enum import Enum + +from graphrag.model import Entity, Relationship +from graphrag.query.input.retrieval.entities import ( + get_entity_by_key, + get_entity_by_name, +) +from graphrag.query.llm.base import BaseTextEmbedding +from graphrag.vector_stores import BaseVectorStore + + +class EntityVectorStoreKey(str, Enum): + """Keys used as ids in the entity embedding vectorstores.""" + + ID = "id" + TITLE = "title" + + @staticmethod + def from_string(value: str) -> "EntityVectorStoreKey": + """Convert string to EntityVectorStoreKey.""" + if value == "id": + return EntityVectorStoreKey.ID + if value == "title": + return EntityVectorStoreKey.TITLE + + msg = f"Invalid EntityVectorStoreKey: {value}" + raise ValueError(msg) + +def map_query_to_entities_in_place( + query: str, + text_embedding_vectorstore: BaseVectorStore, + text_embedder: BaseTextEmbedding, + k: int = 10, + oversample_scaler: int = 2, +) -> list[Entity]: + """Extract entities that match a given query using semantic similarity of text embeddings of query and entity descriptions.""" + # get entities with highest semantic similarity to query + # oversample to account for excluded entities + search_results = text_embedding_vectorstore.get_extracted_entities( + text=query, + text_embedder=lambda t: text_embedder.embed(t), + k=k * oversample_scaler, + ) + import ast + for result in search_results: + result.community_ids = ast.literal_eval(result.community_ids) + return search_results + +def map_query_to_entities( + query: str, + text_embedding_vectorstore: BaseVectorStore, + text_embedder: BaseTextEmbedding, + all_entities: list[Entity], + embedding_vectorstore_key: str = EntityVectorStoreKey.ID, + include_entity_names: list[str] | None = None, + exclude_entity_names: list[str] | None = None, + k: int = 10, + oversample_scaler: int = 2, +) -> list[Entity]: + """Extract entities that match a given query using semantic 
similarity of text embeddings of query and entity descriptions.""" + if all_entities == []: + return map_query_to_entities_in_place( + query, + text_embedding_vectorstore, + text_embedder, + k, + oversample_scaler, + ) + + if include_entity_names is None: + include_entity_names = [] + if exclude_entity_names is None: + exclude_entity_names = [] + matched_entities = [] + if query != "": + # get entities with highest semantic similarity to query + # oversample to account for excluded entities + search_results = text_embedding_vectorstore.similarity_search_by_text( + text=query, + text_embedder=lambda t: text_embedder.embed(t), + k=k * oversample_scaler, + ) + for result in search_results: + matched = get_entity_by_key( + entities=all_entities, + key=embedding_vectorstore_key, + value=result.document.id, + ) + if matched: + matched_entities.append(matched) + else: + all_entities.sort(key=lambda x: x.rank if x.rank else 0, reverse=True) + matched_entities = all_entities[:k] + + # filter out excluded entities + if exclude_entity_names: + matched_entities = [ + entity + for entity in matched_entities + if entity.title not in exclude_entity_names + ] + + # add entities in the include_entity list + included_entities = [] + for entity_name in include_entity_names: + included_entities.extend(get_entity_by_name(all_entities, entity_name)) + return included_entities + matched_entities + + +def find_nearest_neighbors_by_graph_embeddings( + entity_id: str, + graph_embedding_vectorstore: BaseVectorStore, + all_entities: list[Entity], + exclude_entity_names: list[str] | None = None, + embedding_vectorstore_key: str = EntityVectorStoreKey.ID, + k: int = 10, + oversample_scaler: int = 2, +) -> list[Entity]: + """Retrieve related entities by graph embeddings.""" + if exclude_entity_names is None: + exclude_entity_names = [] + # find nearest neighbors of this entity using graph embedding + query_entity = get_entity_by_key( + entities=all_entities, key=embedding_vectorstore_key, value=entity_id + ) + query_embedding = query_entity.graph_embedding if query_entity else None + + # oversample to account for excluded entities + if query_embedding: + matched_entities = [] + search_results = graph_embedding_vectorstore.similarity_search_by_vector( + query_embedding=query_embedding, k=k * oversample_scaler + ) + for result in search_results: + matched = get_entity_by_key( + entities=all_entities, + key=embedding_vectorstore_key, + value=result.document.id, + ) + if matched: + matched_entities.append(matched) + + # filter out excluded entities + if exclude_entity_names: + matched_entities = [ + entity + for entity in matched_entities + if entity.title not in exclude_entity_names + ] + matched_entities.sort(key=lambda x: x.rank, reverse=True) + return matched_entities[:k] + + return [] + + +def find_nearest_neighbors_by_entity_rank( + entity_name: str, + all_entities: list[Entity], + all_relationships: list[Relationship], + exclude_entity_names: list[str] | None = None, + k: int | None = 10, +) -> list[Entity]: + """Retrieve entities that have direct connections with the target entity, sorted by entity rank.""" + if exclude_entity_names is None: + exclude_entity_names = [] + entity_relationships = [ + rel + for rel in all_relationships + if rel.source == entity_name or rel.target == entity_name + ] + source_entity_names = {rel.source for rel in entity_relationships} + target_entity_names = {rel.target for rel in entity_relationships} + related_entity_names = (source_entity_names.union(target_entity_names)).difference( + 
set(exclude_entity_names) + ) + top_relations = [ + entity for entity in all_entities if entity.title in related_entity_names + ] + top_relations.sort(key=lambda x: x.rank if x.rank else 0, reverse=True) + if k: + return top_relations[:k] + return top_relations diff --git a/func-app/graphrag/query/context_builder/local_context.py b/func-app/graphrag/query/context_builder/local_context.py new file mode 100644 index 0000000000..b48bf0bf96 --- /dev/null +++ b/func-app/graphrag/query/context_builder/local_context.py @@ -0,0 +1,360 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Local Context Builder.""" + +from collections import defaultdict +from typing import Any, cast + +import pandas as pd +from common.graph_db_client import GraphDBClient +import tiktoken + +from graphrag.model import Covariate, Entity, Relationship +from graphrag.query.input.retrieval.covariates import ( + get_candidate_covariates, + to_covariate_dataframe, +) +from graphrag.query.input.retrieval.entities import to_entity_dataframe +from graphrag.query.input.retrieval.relationships import ( + get_candidate_relationships, + get_entities_from_relationships, + get_in_network_relationships, + get_out_network_relationships, + to_relationship_dataframe, +) +from graphrag.query.llm.text_utils import num_tokens + + +def build_entity_context( + selected_entities: list[Entity], + token_encoder: tiktoken.Encoding | None = None, + max_tokens: int = 8000, + include_entity_rank: bool = True, + rank_description: str = "number of relationships", + column_delimiter: str = "|", + context_name="Entities", + is_optimized_search: bool = False +) -> tuple[str, pd.DataFrame]: + """Prepare entity data table as context data for system prompt.""" + if len(selected_entities) == 0: + return "", pd.DataFrame() + + # add headers + current_context_text = f"-----{context_name}-----" + "\n" + header = ["id", "entity", "description"] + if include_entity_rank: + header.append(rank_description) + attribute_cols = ( + list(selected_entities[0].attributes.keys()) + if selected_entities[0].attributes + else [] + ) + header.extend(attribute_cols) + current_context_text += column_delimiter.join(header) + "\n" + current_tokens = num_tokens(current_context_text, token_encoder) + + all_context_records = [header] + for entity in selected_entities: + new_context = [ + entity.short_id if entity.short_id else "", + entity.title, + entity.description if entity.description else "", + ] + if include_entity_rank: + new_context.append(str(entity.rank)) + for field in attribute_cols: + field_value = ( + str(entity.attributes.get(field)) + if entity.attributes and entity.attributes.get(field) + else "" + ) + new_context.append(field_value) + new_tokens: int = 0 + if not is_optimized_search: + new_context_text = column_delimiter.join(new_context) + "\n" + new_tokens = num_tokens(new_context_text, token_encoder) + if current_tokens + new_tokens > max_tokens: + break + current_context_text += new_context_text + all_context_records.append(new_context) + current_tokens += new_tokens + + if len(all_context_records) > 1: + record_df = pd.DataFrame( + all_context_records[1:], columns=cast(Any, all_context_records[0]) + ) + else: + record_df = pd.DataFrame() + + return current_context_text, record_df + + +def build_covariates_context( + selected_entities: list[Entity], + covariates: list[Covariate], + token_encoder: tiktoken.Encoding | None = None, + max_tokens: int = 8000, + column_delimiter: str = "|", + context_name: str = "Covariates", + 
is_optimized_search: bool = False +) -> tuple[str, pd.DataFrame]: + """Prepare covariate data tables as context data for system prompt.""" + # create an empty list of covariates + if len(selected_entities) == 0 or len(covariates) == 0: + return "", pd.DataFrame() + + selected_covariates = list[Covariate]() + record_df = pd.DataFrame() + + # add context header + current_context_text = f"-----{context_name}-----" + "\n" + + # add header + header = ["id", "entity"] + attributes = covariates[0].attributes or {} if len(covariates) > 0 else {} + attribute_cols = list(attributes.keys()) if len(covariates) > 0 else [] + header.extend(attribute_cols) + current_context_text += column_delimiter.join(header) + "\n" + current_tokens = num_tokens(current_context_text, token_encoder) + + all_context_records = [header] + for entity in selected_entities: + selected_covariates.extend([ + cov for cov in covariates if cov.subject_id == entity.title + ]) + + for covariate in selected_covariates: + new_context = [ + covariate.short_id if covariate.short_id else "", + covariate.subject_id, + ] + for field in attribute_cols: + field_value = ( + str(covariate.attributes.get(field)) + if covariate.attributes and covariate.attributes.get(field) + else "" + ) + new_context.append(field_value) + + new_context_text = column_delimiter.join(new_context) + "\n" + new_tokens = num_tokens(new_context_text, token_encoder) + if current_tokens + new_tokens > max_tokens: + break + current_context_text += new_context_text + all_context_records.append(new_context) + current_tokens += new_tokens + + if len(all_context_records) > 1: + record_df = pd.DataFrame( + all_context_records[1:], columns=cast(Any, all_context_records[0]) + ) + else: + record_df = pd.DataFrame() + + return current_context_text, record_df + + +def build_relationship_context( + selected_entities: list[Entity], + relationships: list[Relationship], + token_encoder: tiktoken.Encoding | None = None, + include_relationship_weight: bool = False, + max_tokens: int = 8000, + top_k_relationships: int = 10, + relationship_ranking_attribute: str = "rank", + column_delimiter: str = "|", + context_name: str = "Relationships", + is_optimized_search: bool = False, + graphdb_client: GraphDBClient|None=None, +) -> tuple[str, pd.DataFrame]: + """Prepare relationship data tables as context data for system prompt.""" + selected_relationships = _filter_relationships( + selected_entities=selected_entities, + relationships=relationships, + top_k_relationships=top_k_relationships, + relationship_ranking_attribute=relationship_ranking_attribute, + graphdb_client=graphdb_client, + ) + + if len(selected_entities) == 0 or len(selected_relationships) == 0: + return "", pd.DataFrame() + + # add headers + current_context_text = f"-----{context_name}-----" + "\n" + header = ["id", "source", "target", "description"] + if include_relationship_weight: + header.append("weight") + attribute_cols = ( + list(selected_relationships[0].attributes.keys()) + if selected_relationships[0].attributes + else [] + ) + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + + current_context_text += column_delimiter.join(header) + "\n" + current_tokens = num_tokens(current_context_text, token_encoder) + + all_context_records = [header] + for rel in selected_relationships: + new_context = [ + rel.short_id if rel.short_id else "", + rel.source, + rel.target, + rel.description if rel.description else "", + ] + if include_relationship_weight: + 
new_context.append(str(rel.weight if rel.weight else "")) + for field in attribute_cols: + field_value = ( + str(rel.attributes.get(field)) + if rel.attributes and rel.attributes.get(field) + else "" + ) + new_context.append(field_value) + new_context_text = "" + new_tokens = 0 + if not is_optimized_search: + new_context_text = column_delimiter.join(new_context) + "\n" + new_tokens = num_tokens(new_context_text, token_encoder) + if current_tokens + new_tokens > max_tokens: #General: There could be side impact of generating huge number of relationships + break + current_context_text += new_context_text + all_context_records.append(new_context) + current_tokens += new_tokens + + if len(all_context_records) > 1: + record_df = pd.DataFrame( + all_context_records[1:], columns=cast(Any, all_context_records[0]) + ) + else: + record_df = pd.DataFrame() + + return current_context_text, record_df + + +def _filter_relationships( + selected_entities: list[Entity], + relationships: list[Relationship], + top_k_relationships: int = 10, + relationship_ranking_attribute: str = "rank", + graphdb_client: GraphDBClient|None=None, +) -> list[Relationship]: + """Filter and sort relationships based on a set of selected entities and a ranking attribute.""" + # First priority: in-network relationships (i.e. relationships between selected entities) + in_network_relationships = get_in_network_relationships( + selected_entities=selected_entities, + relationships=relationships, + ranking_attribute=relationship_ranking_attribute, + graphdb_client=graphdb_client, + ) + + # Second priority - out-of-network relationships + # (i.e. relationships between selected entities and other entities that are not within the selected entities) + out_network_relationships = get_out_network_relationships( + selected_entities=selected_entities, + relationships=relationships, + ranking_attribute=relationship_ranking_attribute, + graphdb_client=graphdb_client, + ) + if len(out_network_relationships) <= 1: + return in_network_relationships + out_network_relationships + + # within out-of-network relationships, prioritize mutual relationships + # (i.e. 
relationships with out-network entities that are shared with multiple selected entities) + selected_entity_names = [entity.title for entity in selected_entities] + out_network_source_names = [ + relationship.source + for relationship in out_network_relationships + if relationship.source not in selected_entity_names + ] + out_network_target_names = [ + relationship.target + for relationship in out_network_relationships + if relationship.target not in selected_entity_names + ] + out_network_entity_names = list( + set(out_network_source_names + out_network_target_names) + ) + out_network_entity_links = defaultdict(int) + for entity_name in out_network_entity_names: + targets = [ + relationship.target + for relationship in out_network_relationships + if relationship.source == entity_name + ] + sources = [ + relationship.source + for relationship in out_network_relationships + if relationship.target == entity_name + ] + out_network_entity_links[entity_name] = len(set(targets + sources)) + + # sort out-network relationships by number of links and rank_attributes + for rel in out_network_relationships: + if rel.attributes is None: + rel.attributes = {} + rel.attributes["links"] = ( + out_network_entity_links[rel.source] + if rel.source in out_network_entity_links + else out_network_entity_links[rel.target] + ) + + # sort by attributes[links] first, then by ranking_attribute + if relationship_ranking_attribute == "weight": + out_network_relationships.sort( + key=lambda x: (x.attributes["links"], x.weight), # type: ignore + reverse=True, # type: ignore + ) + else: + out_network_relationships.sort( + key=lambda x: ( + x.attributes["links"], # type: ignore + x.attributes[relationship_ranking_attribute], # type: ignore + ), # type: ignore + reverse=True, + ) + + relationship_budget = top_k_relationships * len(selected_entities) + return in_network_relationships + out_network_relationships[:relationship_budget] + + +def get_candidate_context( + selected_entities: list[Entity], + entities: list[Entity], + relationships: list[Relationship], + covariates: dict[str, list[Covariate]], + include_entity_rank: bool = True, + entity_rank_description: str = "number of relationships", + include_relationship_weight: bool = False, +) -> dict[str, pd.DataFrame]: + """Prepare entity, relationship, and covariate data tables as context data for system prompt.""" + candidate_context = {} + candidate_relationships = get_candidate_relationships( + selected_entities=selected_entities, + relationships=relationships, + ) + candidate_context["relationships"] = to_relationship_dataframe( + relationships=candidate_relationships, + include_relationship_weight=include_relationship_weight, + ) + candidate_entities = get_entities_from_relationships( + relationships=candidate_relationships, entities=entities + ) + candidate_context["entities"] = to_entity_dataframe( + entities=candidate_entities, + include_entity_rank=include_entity_rank, + rank_description=entity_rank_description, + ) + + for covariate in covariates: + candidate_covariates = get_candidate_covariates( + selected_entities=selected_entities, + covariates=covariates[covariate], + ) + candidate_context[covariate.lower()] = to_covariate_dataframe( + candidate_covariates + ) + + return candidate_context diff --git a/func-app/graphrag/query/context_builder/source_context.py b/func-app/graphrag/query/context_builder/source_context.py new file mode 100644 index 0000000000..99b7791ca6 --- /dev/null +++ b/func-app/graphrag/query/context_builder/source_context.py @@ -0,0 +1,110 
@@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Context Build utility methods.""" + +import random +from typing import Any, cast + +import pandas as pd +import tiktoken + +from graphrag.model import Entity, Relationship, TextUnit +from graphrag.query.llm.text_utils import num_tokens + +""" +Contain util functions to build text unit context for the search's system prompt +""" + + +def build_text_unit_context( + text_units: list[TextUnit], + token_encoder: tiktoken.Encoding | None = None, + column_delimiter: str = "|", + shuffle_data: bool = True, + max_tokens: int = 8000, + context_name: str = "Sources", + random_state: int = 86, +) -> tuple[str, dict[str, pd.DataFrame]]: + """Prepare text-unit data table as context data for system prompt.""" + if text_units is None or len(text_units) == 0: + return ("", {}) + + if shuffle_data: + random.seed(random_state) + random.shuffle(text_units) + + # add context header + current_context_text = f"-----{context_name}-----" + "\n" + + # add header + header = ["id", "text"] + attribute_cols = ( + list(text_units[0].attributes.keys()) if text_units[0].attributes else [] + ) + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + + current_context_text += column_delimiter.join(header) + "\n" + current_tokens = num_tokens(current_context_text, token_encoder) + all_context_records = [header] + + for unit in text_units: + new_context = [ + unit.short_id, + unit.text, + *[ + str(unit.attributes.get(field, "")) if unit.attributes else "" + for field in attribute_cols + ], + ] + new_context_text = column_delimiter.join(new_context) + "\n" + new_tokens = num_tokens(new_context_text, token_encoder) + + if current_tokens + new_tokens > max_tokens: + break + + current_context_text += new_context_text + all_context_records.append(new_context) + current_tokens += new_tokens + + if len(all_context_records) > 1: + record_df = pd.DataFrame( + all_context_records[1:], columns=cast(Any, all_context_records[0]) + ) + else: + record_df = pd.DataFrame() + return current_context_text, {context_name.lower(): record_df} + + +def count_relationships( + text_unit: TextUnit, entity: Entity, relationships: dict[str, Relationship] +) -> int: + """Count the number of relationships of the selected entity that are associated with the text unit.""" + matching_relationships = list[Relationship]() + if text_unit.relationship_ids is None: + entity_relationships = [ + rel + for rel in relationships.values() + if rel.source == entity.title or rel.target == entity.title + ] + entity_relationships = [ + rel for rel in entity_relationships if rel.text_unit_ids + ] + matching_relationships = [ + rel + for rel in entity_relationships + if text_unit.id in rel.text_unit_ids # type: ignore + ] # type: ignore + else: + text_unit_relationships = [ + relationships[rel_id] + for rel_id in text_unit.relationship_ids + if rel_id in relationships + ] + matching_relationships = [ + rel + for rel in text_unit_relationships + if rel.source == entity.title or rel.target == entity.title + ] + return len(matching_relationships) diff --git a/func-app/graphrag/query/factories.py b/func-app/graphrag/query/factories.py new file mode 100644 index 0000000000..28caf61bb0 --- /dev/null +++ b/func-app/graphrag/query/factories.py @@ -0,0 +1,211 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Query Factory methods to support CLI.""" + +from graphrag.config.models.graphdb_config import GraphDBConfig +import tiktoken +from azure.identity import ManagedIdentityCredential, get_bearer_token_provider + +from graphrag.config import ( + GraphRagConfig, + LLMType, +) +from graphrag.model import ( + CommunityReport, + Covariate, + Entity, + Relationship, + TextUnit, +) +from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey +from graphrag.query.llm.oai.chat_openai import ChatOpenAI +from graphrag.query.llm.oai.embedding import OpenAIEmbedding +from graphrag.query.llm.oai.typing import OpenaiApiType +from graphrag.query.structured_search.global_search.community_context import ( + GlobalCommunityContext, +) +from graphrag.query.structured_search.global_search.search import GlobalSearch +from graphrag.query.structured_search.local_search.mixed_context import ( + LocalSearchMixedContext, +) +from graphrag.query.structured_search.local_search.search import LocalSearch +from graphrag.vector_stores import BaseVectorStore + + +def get_llm(config: GraphRagConfig) -> ChatOpenAI: + """Get the LLM client.""" + is_azure_client = ( + config.llm.type == LLMType.AzureOpenAIChat + or config.llm.type == LLMType.AzureOpenAI + ) + debug_llm_key = config.llm.api_key or "" + llm_debug_info = { + **config.llm.model_dump(), + "api_key": f"REDACTED,len={len(debug_llm_key)}", + } + if config.llm.cognitive_services_endpoint is None: + cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default" + else: + cognitive_services_endpoint = config.llm.cognitive_services_endpoint + print(f"creating llm client with {llm_debug_info}") # noqa T201 + return ChatOpenAI( + api_key=config.llm.api_key, + azure_ad_token_provider=( + get_bearer_token_provider( + ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint + ) + if is_azure_client and not config.llm.api_key + else None + ), + api_base=config.llm.api_base, + organization=config.llm.organization, + model=config.llm.model, + api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI, + deployment_name=config.llm.deployment_name, + api_version=config.llm.api_version, + max_retries=config.llm.max_retries, + ) + + +def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding: + """Get the LLM client for embeddings.""" + is_azure_client = config.embeddings.llm.type == LLMType.AzureOpenAIEmbedding + debug_embedding_api_key = config.embeddings.llm.api_key or "" + llm_debug_info = { + **config.embeddings.llm.model_dump(), + "api_key": f"REDACTED,len={len(debug_embedding_api_key)}", + } + if config.embeddings.llm.cognitive_services_endpoint is None: + cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default" + else: + cognitive_services_endpoint = config.embeddings.llm.cognitive_services_endpoint + print(f"creating embedding llm client with {llm_debug_info}") # noqa T201 + return OpenAIEmbedding( + api_key=config.embeddings.llm.api_key, + azure_ad_token_provider=( + get_bearer_token_provider( + ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint + ) + if is_azure_client and not config.embeddings.llm.api_key + else None + ), + api_base=config.embeddings.llm.api_base, + organization=config.llm.organization, + api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI, + model=config.embeddings.llm.model, + 
deployment_name=config.embeddings.llm.deployment_name, + api_version=config.embeddings.llm.api_version, + max_retries=config.embeddings.llm.max_retries, + ) + + +def get_local_search_engine( + config: GraphRagConfig, + reports: list[CommunityReport], + text_units: list[TextUnit], + entities: list[Entity], + relationships: list[Relationship], + covariates: dict[str, list[Covariate]], + response_type: str, + description_embedding_store: BaseVectorStore, + context_id: str, + is_optimized_search: bool = False, + use_kusto_community_reports: bool = False, + graphdb_config: GraphDBConfig|None = None, +) -> LocalSearch: + """Create a local search engine based on data + configuration.""" + llm = get_llm(config) + text_embedder = get_text_embedder(config) + token_encoder = tiktoken.get_encoding(config.encoding_model) + + ls_config = config.local_search + + return LocalSearch( + llm=llm, + context_builder=LocalSearchMixedContext( + community_reports=reports, + text_units=text_units, + entities=entities, + relationships=relationships, + covariates=covariates, + entity_text_embeddings=description_embedding_store, + embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE + text_embedder=text_embedder, + token_encoder=token_encoder, + is_optimized_search= is_optimized_search, + use_kusto_community_reports=use_kusto_community_reports, + graphdb_config=graphdb_config, + context_id=context_id, + ), + token_encoder=token_encoder, + llm_params={ + "max_tokens": ls_config.llm_max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500) + "temperature": ls_config.temperature, + "top_p": ls_config.top_p, + "n": ls_config.n, + }, + context_builder_params={ + "text_unit_prop": ls_config.text_unit_prop, + "community_prop": ls_config.community_prop, + "conversation_history_max_turns": ls_config.conversation_history_max_turns, + "conversation_history_user_turns_only": True, + "top_k_mapped_entities": ls_config.top_k_entities, + "top_k_relationships": ls_config.top_k_relationships, + "include_entity_rank": True, + "include_relationship_weight": True, + "include_community_rank": False, + "return_candidate_context": False, + "embedding_vectorstore_key": EntityVectorStoreKey.ID, # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids + "max_tokens": ls_config.max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) + }, + response_type=response_type, + ) + + +def get_global_search_engine( + config: GraphRagConfig, + reports: list[CommunityReport], + entities: list[Entity], + response_type: str, +): + """Create a global search engine based on data + configuration.""" + token_encoder = tiktoken.get_encoding(config.encoding_model) + gs_config = config.global_search + + return GlobalSearch( + llm=get_llm(config), + context_builder=GlobalCommunityContext( + community_reports=reports, entities=entities, token_encoder=token_encoder + ), + token_encoder=token_encoder, + max_data_tokens=gs_config.data_max_tokens, + map_llm_params={ + "max_tokens": gs_config.map_max_tokens, + "temperature": gs_config.temperature, + "top_p": gs_config.top_p, + "n": gs_config.n, + }, + reduce_llm_params={ + "max_tokens": gs_config.reduce_max_tokens, + "temperature": gs_config.temperature, + "top_p": gs_config.top_p, + "n": gs_config.n, + }, + 
allow_general_knowledge=False,
+        json_mode=False,
+        context_builder_params={
+            "use_community_summary": False,
+            "shuffle_data": True,
+            "include_community_rank": True,
+            "min_community_rank": 0,
+            "community_rank_name": "rank",
+            "include_community_weight": True,
+            "community_weight_name": "occurrence weight",
+            "normalize_community_weight": True,
+            "max_tokens": gs_config.max_tokens,
+            "context_name": "Reports",
+        },
+        concurrent_coroutines=gs_config.concurrency,
+        response_type=response_type,
+    )
diff --git a/func-app/graphrag/query/indexer_adapters.py b/func-app/graphrag/query/indexer_adapters.py
new file mode 100644
index 0000000000..101fc16f9c
--- /dev/null
+++ b/func-app/graphrag/query/indexer_adapters.py
@@ -0,0 +1,159 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+"""Indexing-Engine to Query Read Adapters.
+
+The parts of these functions that do type adaptation, renaming, collating, etc. should eventually go away.
+Ideally this is just a straight read-through into the object model.
+"""
+
+from typing import cast
+
+import pandas as pd
+
+from graphrag.model import CommunityReport, Covariate, Entity, Relationship, TextUnit
+from graphrag.query.input.loaders.dfs import (
+    read_community_reports,
+    read_covariates,
+    read_entities,
+    read_relationships,
+    read_text_units,
+)
+
+from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
+
+def read_indexer_text_units(final_text_units: pd.DataFrame) -> list[TextUnit]:
+    """Read in the Text Units from the raw indexing outputs."""
+    return read_text_units(
+        df=final_text_units,
+        short_id_col=None,
+        # expects a covariate map of type -> ids
+        covariates_col=None,
+    )
+
+
+def read_indexer_covariates(final_covariates: pd.DataFrame) -> list[Covariate]:
+    """Read in the Claims from the raw indexing outputs."""
+    covariate_df = final_covariates
+    covariate_df["id"] = covariate_df["id"].astype(str)
+    return read_covariates(
+        df=covariate_df,
+        short_id_col="human_readable_id",
+        attributes_cols=[
+            "object_id",
+            "status",
+            "start_date",
+            "end_date",
+            "description",
+        ],
+        text_unit_ids_col=None,
+    )
+
+# GraphDB: read relationships from the graph db.
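+# (placeholder: the adapter below still reads relationships from the parquet
+# outputs, not from the graph db)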
+def read_indexer_relationships(final_relationships: pd.DataFrame) -> list[Relationship]:
+    """Read in the Relationships from the raw indexing outputs."""
+    return read_relationships(
+        df=final_relationships,
+        short_id_col="human_readable_id",
+        description_embedding_col=None,
+        document_ids_col=None,
+        attributes_cols=["rank"],
+    )
+
+def kt_read_indexer_reports(
+    vs: VectorStoreType.Kusto,
+    community_level: int,
+) -> bool:
+    """Materialize community reports at or below community_level into the interm_rep Kusto table."""
+    vs.client.execute(vs.database, '.drop table interm_rep ifexists')
+
+    cmd = f'''
+    .set interm_rep <| (create_final_community_reports | where level <= {community_level} |
+    join kind=inner (create_final_nodes |
+    where level <= {community_level} | summarize community=max(community) by ['title'] | summarize by community )
+    on community | project-away community1)
+    '''
+
+    res = vs.client.execute(vs.database, cmd)
+    return True #TODO: error checking should be added later
+
+def read_indexer_reports(
+    final_community_reports: pd.DataFrame,
+    final_nodes: pd.DataFrame,
+    community_level: int,
+) -> list[CommunityReport]:
+    """Read in the Community Reports from the raw indexing outputs."""
+    report_df = final_community_reports
+    entity_df = final_nodes
+    entity_df = _filter_under_community_level(entity_df, community_level)
+    entity_df["community"] = entity_df["community"].fillna(-1)
+    entity_df["community"] = entity_df["community"].astype(int)
+
+    entity_df = entity_df.groupby(["title"]).agg({"community": "max"}).reset_index()
+    entity_df["community"] = entity_df["community"].astype(str)
+    filtered_community_df = entity_df["community"].drop_duplicates()
+
+    report_df = _filter_under_community_level(report_df, community_level)
+    report_df = report_df.merge(filtered_community_df, on="community", how="inner")
+    report_df = report_df.drop_duplicates(subset=["community"])
+
+    return read_community_reports(
+        df=report_df,
+        id_col="community",
+        short_id_col="community",
+        summary_embedding_col=None,
+        content_embedding_col=None,
+    )
+
+
+def read_indexer_entities(
+    final_nodes: pd.DataFrame,
+    final_entities: pd.DataFrame,
+    community_level: int,
+) -> list[Entity]:
+    """Read in the Entities from the raw indexing outputs."""
+    entity_df = final_nodes
+    entity_embedding_df = final_entities
+
+    entity_df = _filter_under_community_level(entity_df, community_level)
+    entity_df = cast(pd.DataFrame, entity_df[["title", "degree", "community"]]).rename(
+        columns={"title": "name", "degree": "rank"}
+    )
+
+    entity_df["community"] = entity_df["community"].fillna(-1)
+    entity_df["community"] = entity_df["community"].astype(int)
+    entity_df["rank"] = entity_df["rank"].astype(int)
+
+    # for duplicate entities, keep the one with the highest community level
+    entity_df = (
+        entity_df.groupby(["name", "rank"]).agg({"community": "max"}).reset_index()
+    )
+    entity_df["community"] = entity_df["community"].apply(lambda x: [str(x)])
+    entity_df = entity_df.merge(
+        entity_embedding_df, on="name", how="inner"
+    ).drop_duplicates(subset=["name"])
+
+    # read entity dataframe to knowledge model objects
+    return read_entities(
+        df=entity_df,
+        id_col="id",
+        title_col="name",
+        type_col="type",
+        short_id_col="human_readable_id",
+        description_col="description",
+        community_col="community",
+        rank_col="rank",
+        name_embedding_col=None,
+        description_embedding_col="description_embedding",
+        graph_embedding_col=None,
+        text_unit_ids_col="text_unit_ids",
+        document_ids_col=None,
+    )
+
+
+def _filter_under_community_level(
+    df: pd.DataFrame, community_level: int
+) -> pd.DataFrame:
+    return cast(
+        pd.DataFrame,
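+        # boolean mask keeps only rows at or below the requested community level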
+        df[df.level <= community_level],
+    )
diff --git a/func-app/graphrag/query/input/__init__.py b/func-app/graphrag/query/input/__init__.py
new file mode 100644
index 0000000000..94ae973477
--- /dev/null
+++ b/func-app/graphrag/query/input/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""GraphRAG Orchestration Inputs."""
diff --git a/func-app/graphrag/query/input/loaders/__init__.py b/func-app/graphrag/query/input/loaders/__init__.py
new file mode 100644
index 0000000000..8f19dac0dd
--- /dev/null
+++ b/func-app/graphrag/query/input/loaders/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""GraphRAG Orchestration Input Loaders."""
diff --git a/func-app/graphrag/query/input/loaders/dfs.py b/func-app/graphrag/query/input/loaders/dfs.py
new file mode 100644
index 0000000000..7312963bb8
--- /dev/null
+++ b/func-app/graphrag/query/input/loaders/dfs.py
@@ -0,0 +1,340 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Load data from dataframes into collections of data objects."""
+
+import pandas as pd
+
+from graphrag.model import (
+    Community,
+    CommunityReport,
+    Covariate,
+    Document,
+    Entity,
+    Relationship,
+    TextUnit,
+)
+from graphrag.query.input.loaders.utils import (
+    to_list,
+    to_optional_dict,
+    to_optional_float,
+    to_optional_int,
+    to_optional_list,
+    to_optional_str,
+    to_str,
+)
+from graphrag.vector_stores import BaseVectorStore, VectorStoreDocument
+
+
+def read_entities(
+    df: pd.DataFrame,
+    id_col: str = "id",
+    short_id_col: str | None = "short_id",
+    title_col: str = "title",
+    type_col: str | None = "type",
+    description_col: str | None = "description",
+    name_embedding_col: str | None = "name_embedding",
+    description_embedding_col: str | None = "description_embedding",
+    graph_embedding_col: str | None = "graph_embedding",
+    community_col: str | None = "community_ids",
+    text_unit_ids_col: str | None = "text_unit_ids",
+    document_ids_col: str | None = "document_ids",
+    rank_col: str | None = "degree",
+    attributes_cols: list[str] | None = None,
+) -> list[Entity]:
+    """Read entities from a dataframe."""
+    entities = []
+    for idx, row in df.iterrows():
+        entity = Entity(
+            id=to_str(row, id_col),
+            short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
+            title=to_str(row, title_col),
+            type=to_optional_str(row, type_col),
+            description=to_optional_str(row, description_col),
+            name_embedding=to_optional_list(row, name_embedding_col, item_type=float),
+            description_embedding=to_optional_list(
+                row, description_embedding_col, item_type=float
+            ),
+            graph_embedding=to_optional_list(row, graph_embedding_col, item_type=float),
+            community_ids=to_optional_list(row, community_col, item_type=str),
+            text_unit_ids=to_optional_list(row, text_unit_ids_col),
+            document_ids=to_optional_list(row, document_ids_col),
+            rank=to_optional_int(row, rank_col),
+            attributes=(
+                {col: row.get(col) for col in attributes_cols}
+                if attributes_cols
+                else None
+            ),
+        )
+        entities.append(entity)
+    return entities
+
+
+def store_entity_semantic_embeddings(
+    entities: list[Entity],
+    vectorstore: BaseVectorStore,
+) -> BaseVectorStore:
+    """Store entity semantic embeddings in a vectorstore."""
+    documents = [
+        VectorStoreDocument(
+            id=entity.id,
+            text=entity.description,
+            vector=entity.description_embedding,
+            attributes=(
+                {"title": entity.title, **entity.attributes}
+                if entity.attributes
+                else {"title":
entity.title} + ), + ) + for entity in entities + ] + vectorstore.load_documents(documents=documents) + return vectorstore + + +def store_entity_behavior_embeddings( + entities: list[Entity], + vectorstore: BaseVectorStore, +) -> BaseVectorStore: + """Store entity behavior embeddings in a vectorstore.""" + documents = [ + VectorStoreDocument( + id=entity.id, + text=entity.description, + vector=entity.graph_embedding, + attributes=( + {"title": entity.title, **entity.attributes} + if entity.attributes + else {"title": entity.title} + ), + ) + for entity in entities + ] + vectorstore.load_documents(documents=documents) + return vectorstore + + +def read_relationships( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str | None = "short_id", + source_col: str = "source", + target_col: str = "target", + description_col: str | None = "description", + description_embedding_col: str | None = "description_embedding", + weight_col: str | None = "weight", + text_unit_ids_col: str | None = "text_unit_ids", + document_ids_col: str | None = "document_ids", + attributes_cols: list[str] | None = None, +) -> list[Relationship]: + """Read relationships from a dataframe.""" + relationships = [] + for idx, row in df.iterrows(): + rel = Relationship( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + source=to_str(row, source_col), + target=to_str(row, target_col), + description=to_optional_str(row, description_col), + description_embedding=to_optional_list( + row, description_embedding_col, item_type=float + ), + weight=to_optional_float(row, weight_col), + text_unit_ids=to_optional_list(row, text_unit_ids_col, item_type=str), + document_ids=to_optional_list(row, document_ids_col, item_type=str), + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + relationships.append(rel) + return relationships + + +def read_covariates( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str | None = "short_id", + subject_col: str = "subject_id", + subject_type_col: str | None = "subject_type", + covariate_type_col: str | None = "covariate_type", + text_unit_ids_col: str | None = "text_unit_ids", + document_ids_col: str | None = "document_ids", + attributes_cols: list[str] | None = None, +) -> list[Covariate]: + """Read covariates from a dataframe.""" + covariates = [] + for idx, row in df.iterrows(): + cov = Covariate( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + subject_id=to_str(row, subject_col), + subject_type=( + to_str(row, subject_type_col) if subject_type_col else "entity" + ), + covariate_type=( + to_str(row, covariate_type_col) if covariate_type_col else "claim" + ), + text_unit_ids=to_optional_list(row, text_unit_ids_col, item_type=str), + document_ids=to_optional_list(row, document_ids_col, item_type=str), + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + covariates.append(cov) + return covariates + + +def read_communities( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str | None = "short_id", + title_col: str = "title", + level_col: str = "level", + entities_col: str | None = "entity_ids", + relationships_col: str | None = "relationship_ids", + covariates_col: str | None = "covariate_ids", + attributes_cols: list[str] | None = None, +) -> list[Community]: + """Read communities from a dataframe.""" + communities = [] + for idx, row in df.iterrows(): + comm = 
Community( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + title=to_str(row, title_col), + level=to_str(row, level_col), + entity_ids=to_optional_list(row, entities_col, item_type=str), + relationship_ids=to_optional_list(row, relationships_col, item_type=str), + covariate_ids=to_optional_dict( + row, covariates_col, key_type=str, value_type=str + ), + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + communities.append(comm) + return communities + + +def read_community_reports( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str | None = "short_id", + title_col: str = "title", + community_col: str = "community", + summary_col: str = "summary", + content_col: str = "full_content", + rank_col: str | None = "rank", + summary_embedding_col: str | None = "summary_embedding", + content_embedding_col: str | None = "full_content_embedding", + attributes_cols: list[str] | None = None, +) -> list[CommunityReport]: + """Read community reports from a dataframe.""" + reports = [] + for idx, row in df.iterrows(): + report = CommunityReport( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + title=to_str(row, title_col), + community_id=to_str(row, community_col), + summary=to_str(row, summary_col), + full_content=to_str(row, content_col), + rank=to_optional_float(row, rank_col), + summary_embedding=to_optional_list( + row, summary_embedding_col, item_type=float + ), + full_content_embedding=to_optional_list( + row, content_embedding_col, item_type=float + ), + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + reports.append(report) + return reports + + +def read_text_units( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str | None = "short_id", + text_col: str = "text", + entities_col: str | None = "entity_ids", + relationships_col: str | None = "relationship_ids", + covariates_col: str | None = "covariate_ids", + tokens_col: str | None = "n_tokens", + document_ids_col: str | None = "document_ids", + embedding_col: str | None = "text_embedding", + attributes_cols: list[str] | None = None, +) -> list[TextUnit]: + """Read text units from a dataframe.""" + text_units = [] + for idx, row in df.iterrows(): + chunk = TextUnit( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + text=to_str(row, text_col), + entity_ids=to_optional_list(row, entities_col, item_type=str), + relationship_ids=to_optional_list(row, relationships_col, item_type=str), + covariate_ids=to_optional_dict( + row, covariates_col, key_type=str, value_type=str + ), + text_embedding=to_optional_list(row, embedding_col, item_type=float), # type: ignore + n_tokens=to_optional_int(row, tokens_col), + document_ids=to_optional_list(row, document_ids_col, item_type=str), + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + text_units.append(chunk) + return text_units + + +def read_documents( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str = "short_id", + title_col: str = "title", + type_col: str = "type", + summary_col: str | None = "entities", + raw_content_col: str | None = "relationships", + summary_embedding_col: str | None = "summary_embedding", + content_embedding_col: str | None = "raw_content_embedding", + text_units_col: str | None = "text_units", + attributes_cols: list[str] | 
None = None, +) -> list[Document]: + """Read documents from a dataframe.""" + docs = [] + for idx, row in df.iterrows(): + doc = Document( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + title=to_str(row, title_col), + type=to_str(row, type_col), + summary=to_optional_str(row, summary_col), + raw_content=to_str(row, raw_content_col), + summary_embedding=to_optional_list( + row, summary_embedding_col, item_type=float + ), + raw_content_embedding=to_optional_list( + row, content_embedding_col, item_type=float + ), + text_units=to_list(row, text_units_col, item_type=str), # type: ignore + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + docs.append(doc) + return docs diff --git a/func-app/graphrag/query/input/loaders/utils.py b/func-app/graphrag/query/input/loaders/utils.py new file mode 100644 index 0000000000..e0fffd2467 --- /dev/null +++ b/func-app/graphrag/query/input/loaders/utils.py @@ -0,0 +1,245 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Data load utils.""" + +import numpy as np +import pandas as pd + + +def to_str(data: pd.Series, column_name: str | None) -> str: + """Convert and validate a value to a string.""" + if column_name is None: + msg = "Column name is None" + raise ValueError(msg) + + if column_name in data: + return str(data[column_name]) + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + +def to_optional_str(data: pd.Series, column_name: str | None) -> str | None: + """Convert and validate a value to an optional string.""" + if column_name is None: + msg = "Column name is None" + raise ValueError(msg) + + if column_name in data: + value = data[column_name] + if value is None: + return None + return str(data[column_name]) + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + +def to_list( + data: pd.Series, column_name: str | None, item_type: type | None = None +) -> list: + """Convert and validate a value to a list.""" + if column_name is None: + msg = "Column name is None" + raise ValueError(msg) + + if column_name in data: + value = data[column_name] + if isinstance(value, np.ndarray): + value = value.tolist() + + if not isinstance(value, list): + msg = f"value is not a list: {value} ({type(value)})" + raise ValueError(msg) + + if item_type is not None: + for v in value: + if not isinstance(v, item_type): + msg = f"list item has item that is not {item_type}: {v} ({type(v)})" + raise TypeError(msg) + return value + + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + +def to_optional_list( + data: pd.Series, column_name: str | None, item_type: type | None = None +) -> list | None: + """Convert and validate a value to an optional list.""" + if column_name is None: + return None + + if column_name in data: + value = data[column_name] # type: ignore + if value is None: + return None + + if isinstance(value, np.ndarray): + value = value.tolist() + + if not isinstance(value, list): + msg = f"value is not a list: {value} ({type(value)})" + raise ValueError(msg) + + if item_type is not None: + for v in value: + if not isinstance(v, item_type): + msg = f"list item has item that is not {item_type}: {v} ({type(v)})" + raise TypeError(msg) + return value + + return None + + +def to_int(data: pd.Series, column_name: str | None) -> int: + """Convert and validate a value to an int.""" + if column_name is None: + msg = "Column name is None" + raise 
ValueError(msg) + + if column_name in data: + value = data[column_name] + if isinstance(value, float): + value = int(value) + if not isinstance(value, int): + msg = f"value is not an int: {value} ({type(value)})" + raise ValueError(msg) + else: + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + return int(value) + + +def to_optional_int(data: pd.Series, column_name: str | None) -> int | None: + """Convert and validate a value to an optional int.""" + if column_name is None: + return None + + if column_name in data: + value = data[column_name] + + if value is None: + return None + + if isinstance(value, float): + value = int(value) + if not isinstance(value, int): + msg = f"value is not an int: {value} ({type(value)})" + raise ValueError(msg) + else: + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + return int(value) + + +def to_float(data: pd.Series, column_name: str | None) -> float: + """Convert and validate a value to a float.""" + if column_name is None: + msg = "Column name is None" + raise ValueError(msg) + + if column_name in data: + value = data[column_name] + if not isinstance(value, float): + msg = f"value is not a float: {value} ({type(value)})" + raise ValueError(msg) + else: + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + return float(value) + + +def to_optional_float(data: pd.Series, column_name: str | None) -> float | None: + """Convert and validate a value to an optional float.""" + if column_name is None: + return None + + if column_name in data: + value = data[column_name] + if value is None: + return None + if not isinstance(value, float): + msg = f"value is not a float: {value} ({type(value)})" + raise ValueError(msg) + else: + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + return float(value) + + +def to_dict( + data: pd.Series, + column_name: str | None, + key_type: type | None = None, + value_type: type | None = None, +) -> dict: + """Convert and validate a value to a dict.""" + if column_name is None: + msg = "Column name is None" + raise ValueError(msg) + + if column_name in data: + value = data[column_name] + if not isinstance(value, dict): + msg = f"value is not a dict: {value} ({type(value)})" + raise ValueError(msg) + + if key_type is not None: + for v in value: + if not isinstance(v, key_type): + msg = f"dict key has item that is not {key_type}: {v} ({type(v)})" + raise TypeError(msg) + + if value_type is not None: + for v in value.values(): + if not isinstance(v, value_type): + msg = ( + f"dict value has item that is not {value_type}: {v} ({type(v)})" + ) + raise TypeError(msg) + return value + + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + +def to_optional_dict( + data: pd.Series, + column_name: str | None, + key_type: type | None = None, + value_type: type | None = None, +) -> dict | None: + """Convert and validate a value to an optional dict.""" + if column_name is None: + return None + + if column_name in data: + value = data[column_name] + if value is None: + return None + if not isinstance(value, dict): + msg = f"value is not a dict: {value} ({type(value)})" + raise TypeError(msg) + + if key_type is not None: + for v in value: + if not isinstance(v, key_type): + msg = f"dict key has item that is not {key_type}: {v} ({type(v)})" + raise TypeError(msg) + + if value_type is not None: + for v in value.values(): + if not isinstance(v, value_type): + msg = ( + f"dict value has item that is not {value_type}: {v} 
({type(v)})" + ) + raise TypeError(msg) + + return value + + msg = f"Column {column_name} not found in data" + raise ValueError(msg) diff --git a/func-app/graphrag/query/input/retrieval/__init__.py b/func-app/graphrag/query/input/retrieval/__init__.py new file mode 100644 index 0000000000..75c2f9f095 --- /dev/null +++ b/func-app/graphrag/query/input/retrieval/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""GraphRAG Orchestration Input Retrieval.""" diff --git a/func-app/graphrag/query/input/retrieval/community_reports.py b/func-app/graphrag/query/input/retrieval/community_reports.py new file mode 100644 index 0000000000..bd4933f1f9 --- /dev/null +++ b/func-app/graphrag/query/input/retrieval/community_reports.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Util functions to retrieve community reports from a collection.""" + +from typing import Any, cast + +import pandas as pd + +from graphrag.model import CommunityReport, Entity + + +def get_candidate_communities( + selected_entities: list[Entity], + community_reports: list[CommunityReport], + include_community_rank: bool = False, + use_community_summary: bool = False, +) -> pd.DataFrame: + """Get all communities that are related to selected entities.""" + selected_community_ids = [ + entity.community_ids for entity in selected_entities if entity.community_ids + ] + selected_community_ids = [ + item for sublist in selected_community_ids for item in sublist + ] + selected_reports = [ + community + for community in community_reports + if community.id in selected_community_ids + ] + return to_community_report_dataframe( + reports=selected_reports, + include_community_rank=include_community_rank, + use_community_summary=use_community_summary, + ) + + +def to_community_report_dataframe( + reports: list[CommunityReport], + include_community_rank: bool = False, + use_community_summary: bool = False, +) -> pd.DataFrame: + """Convert a list of communities to a pandas dataframe.""" + if len(reports) == 0: + return pd.DataFrame() + + # add header + header = ["id", "title"] + attribute_cols = list(reports[0].attributes.keys()) if reports[0].attributes else [] + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + header.append("summary" if use_community_summary else "content") + if include_community_rank: + header.append("rank") + + records = [] + for report in reports: + new_record = [ + report.short_id if report.short_id else "", + report.title, + *[ + str(report.attributes.get(field, "")) + if report.attributes and report.attributes.get(field) + else "" + for field in attribute_cols + ], + ] + new_record.append( + report.summary if use_community_summary else report.full_content + ) + if include_community_rank: + new_record.append(str(report.rank)) + records.append(new_record) + return pd.DataFrame(records, columns=cast(Any, header)) diff --git a/func-app/graphrag/query/input/retrieval/covariates.py b/func-app/graphrag/query/input/retrieval/covariates.py new file mode 100644 index 0000000000..1c45203d01 --- /dev/null +++ b/func-app/graphrag/query/input/retrieval/covariates.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Util functions to retrieve covariates from a collection.""" + +from typing import Any, cast + +import pandas as pd + +from graphrag.model import Covariate, Entity + + +def get_candidate_covariates( + selected_entities: list[Entity], + covariates: list[Covariate], +) -> list[Covariate]: + """Get all covariates that are related to selected entities.""" + selected_entity_names = [entity.title for entity in selected_entities] + return [ + covariate + for covariate in covariates + if covariate.subject_id in selected_entity_names + ] + + +def to_covariate_dataframe(covariates: list[Covariate]) -> pd.DataFrame: + """Convert a list of covariates to a pandas dataframe.""" + if len(covariates) == 0: + return pd.DataFrame() + + # add header + header = ["id", "entity"] + attributes = covariates[0].attributes or {} if len(covariates) > 0 else {} + attribute_cols = list(attributes.keys()) if len(covariates) > 0 else [] + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + + records = [] + for covariate in covariates: + new_record = [ + covariate.short_id if covariate.short_id else "", + covariate.subject_id, + ] + for field in attribute_cols: + field_value = ( + str(covariate.attributes.get(field)) + if covariate.attributes and covariate.attributes.get(field) + else "" + ) + new_record.append(field_value) + records.append(new_record) + return pd.DataFrame(records, columns=cast(Any, header)) diff --git a/func-app/graphrag/query/input/retrieval/entities.py b/func-app/graphrag/query/input/retrieval/entities.py new file mode 100644 index 0000000000..5465f9f59e --- /dev/null +++ b/func-app/graphrag/query/input/retrieval/entities.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Util functions to get entities from a collection.""" + +import uuid +from collections.abc import Iterable +from typing import Any, cast + +import pandas as pd + +from graphrag.model import Entity + + +def get_entity_by_key( + entities: Iterable[Entity], key: str, value: str | int +) -> Entity | None: + """Get entity by key.""" + for entity in entities: + if isinstance(value, str) and is_valid_uuid(value): + if getattr(entity, key) == value or getattr(entity, key) == value.replace( + "-", "" + ): + return entity + else: + if getattr(entity, key) == value: + return entity + return None + + +def get_entity_by_name(entities: Iterable[Entity], entity_name: str) -> list[Entity]: + """Get entities by name.""" + return [entity for entity in entities if entity.title == entity_name] + + +def get_entity_by_attribute( + entities: Iterable[Entity], attribute_name: str, attribute_value: Any +) -> list[Entity]: + """Get entities by attribute.""" + return [ + entity + for entity in entities + if entity.attributes + and entity.attributes.get(attribute_name) == attribute_value + ] + + +def to_entity_dataframe( + entities: list[Entity], + include_entity_rank: bool = True, + rank_description: str = "number of relationships", +) -> pd.DataFrame: + """Convert a list of entities to a pandas dataframe.""" + if len(entities) == 0: + return pd.DataFrame() + header = ["id", "entity", "description"] + if include_entity_rank: + header.append(rank_description) + attribute_cols = ( + list(entities[0].attributes.keys()) if entities[0].attributes else [] + ) + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + + records = [] + for entity in entities: + new_record = [ + entity.short_id if entity.short_id else "", + entity.title, + entity.description if entity.description else "", + ] + if include_entity_rank: + new_record.append(str(entity.rank)) + + for field in attribute_cols: + field_value = ( + str(entity.attributes.get(field)) + if entity.attributes and entity.attributes.get(field) + else "" + ) + new_record.append(field_value) + records.append(new_record) + return pd.DataFrame(records, columns=cast(Any, header)) + + +def is_valid_uuid(value: str) -> bool: + """Determine if a string is a valid UUID.""" + try: + uuid.UUID(str(value)) + except ValueError: + return False + else: + return True diff --git a/func-app/graphrag/query/input/retrieval/relationships.py b/func-app/graphrag/query/input/retrieval/relationships.py new file mode 100644 index 0000000000..7be258d23c --- /dev/null +++ b/func-app/graphrag/query/input/retrieval/relationships.py @@ -0,0 +1,217 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Util functions to retrieve relationships from a collection.""" + +import time +from typing import Any, cast + +import pandas as pd + +from common.graph_db_client import GraphDBClient +from graphrag.model import Entity, Relationship + +from graphrag.query.input.loaders.dfs import read_relationships + +def get_relationships_from_graphdb(query:str,selected_entity_names:list[str],graphdb_client: GraphDBClient): + relationships_result=graphdb_client._client.submit( + message=query, + bindings={ + "prop_selected_entity_names": selected_entity_names, + } + ) + time.sleep(5) + print(graphdb_client.result_to_df(relationships_result)) + return read_relationships( + graphdb_client.result_to_df(relationships_result), + short_id_col="human_readable_id" + ) + +def get_in_network_relationships( + selected_entities: list[Entity], + relationships: list[Relationship], + ranking_attribute: str = "rank", + graphdb_client: GraphDBClient|None=None, +) -> list[Relationship]: + """Get all directed relationships between selected entities, sorted by ranking_attribute.""" + selected_entity_names = [entity.title for entity in selected_entities] + if not graphdb_client: + selected_relationships = [ + relationship + for relationship in relationships + if relationship.source in selected_entity_names + and relationship.target in selected_entity_names + ] + else: + selected_relationships = get_relationships_from_graphdb( + query=( + "g.E()" + ".where(inV().has('name',within(prop_selected_entity_names)))" + ".where(outV().has('name',within(prop_selected_entity_names)))" + ), + selected_entity_names=selected_entity_names, + graphdb_client=graphdb_client + ) + if len(selected_relationships) <= 1: + return selected_relationships + + # sort by ranking attribute + return sort_relationships_by_ranking_attribute( + selected_relationships, selected_entities, ranking_attribute + ) + + +def get_out_network_relationships( + selected_entities: list[Entity], + relationships: list[Relationship], + ranking_attribute: str = "rank", + graphdb_client: GraphDBClient|None=None, +) -> list[Relationship]: + """Get relationships from selected entities to other entities that are not within the selected entities, sorted by ranking_attribute.""" + selected_entity_names = [entity.title for entity in selected_entities] + if not graphdb_client: + source_relationships = [ + relationship + for relationship in relationships + if relationship.source in selected_entity_names + and relationship.target not in selected_entity_names + ] + target_relationships = [ + relationship + for relationship in relationships + if relationship.target in selected_entity_names + and relationship.source not in selected_entity_names + ] + selected_relationships = source_relationships + target_relationships + else: + selected_relationships = get_relationships_from_graphdb( + query=( + "g.E().union(" + "__.where(outV().has('name',without(prop_selected_entity_names)))" + ".where(inV().has('name',within(prop_selected_entity_names)))," + "__.where(inV().has('name',without(prop_selected_entity_names)))" + ".where(outV().has('name',within(prop_selected_entity_names)))" + ")" + ), + selected_entity_names= selected_entity_names, + graphdb_client=graphdb_client + ) + return sort_relationships_by_ranking_attribute( + selected_relationships, selected_entities, ranking_attribute + ) + + +def get_candidate_relationships( + selected_entities: list[Entity], + relationships: list[Relationship], +) -> list[Relationship]: + """Get all relationships that 
are associated with the selected entities.""" + selected_entity_names = [entity.title for entity in selected_entities] + return [ + relationship + for relationship in relationships + if relationship.source in selected_entity_names + or relationship.target in selected_entity_names + ] + + +def get_entities_from_relationships( + relationships: list[Relationship], entities: list[Entity] +) -> list[Entity]: + """Get all entities that are associated with the selected relationships.""" + selected_entity_names = [relationship.source for relationship in relationships] + [ + relationship.target for relationship in relationships + ] + return [entity for entity in entities if entity.title in selected_entity_names] + + +def calculate_relationship_combined_rank( + relationships: list[Relationship], + entities: list[Entity], + ranking_attribute: str = "rank", +) -> list[Relationship]: + """Calculate default rank for a relationship based on the combined rank of source and target entities.""" + entity_mappings = {entity.title: entity for entity in entities} + + for relationship in relationships: + if relationship.attributes is None: + relationship.attributes = {} + source = entity_mappings.get(relationship.source) + target = entity_mappings.get(relationship.target) + source_rank = source.rank if source and source.rank else 0 + target_rank = target.rank if target and target.rank else 0 + relationship.attributes[ranking_attribute] = source_rank + target_rank # type: ignore + return relationships + + +def sort_relationships_by_ranking_attribute( + relationships: list[Relationship], + entities: list[Entity], + ranking_attribute: str = "rank", +) -> list[Relationship]: + """ + Sort relationships by a ranking_attribute. + + If no ranking attribute exists, sort by combined rank of source and target entities. 
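+    The sort is descending, so the highest-ranked relationships come first.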
+    """
+    if len(relationships) == 0:
+        return relationships
+
+    # sort by ranking attribute
+    attribute_names = (
+        list(relationships[0].attributes.keys()) if relationships[0].attributes else []
+    )
+    if ranking_attribute in attribute_names:
+        relationships.sort(
+            key=lambda x: int(x.attributes[ranking_attribute]) if x.attributes else 0,
+            reverse=True,
+        )
+    elif ranking_attribute == "weight":
+        relationships.sort(key=lambda x: x.weight if x.weight else 0.0, reverse=True)
+    else:
+        # ranking attribute does not exist; fall back to the combined rank of source and target
+        relationships = calculate_relationship_combined_rank(
+            relationships, entities, ranking_attribute
+        )
+        relationships.sort(
+            key=lambda x: int(x.attributes[ranking_attribute]) if x.attributes else 0,
+            reverse=True,
+        )
+    return relationships
+
+
+def to_relationship_dataframe(
+    relationships: list[Relationship], include_relationship_weight: bool = True
+) -> pd.DataFrame:
+    """Convert a list of relationships to a pandas dataframe."""
+    if len(relationships) == 0:
+        return pd.DataFrame()
+
+    header = ["id", "source", "target", "description"]
+    if include_relationship_weight:
+        header.append("weight")
+    attribute_cols = (
+        list(relationships[0].attributes.keys()) if relationships[0].attributes else []
+    )
+    attribute_cols = [col for col in attribute_cols if col not in header]
+    header.extend(attribute_cols)
+
+    records = []
+    for rel in relationships:
+        new_record = [
+            rel.short_id if rel.short_id else "",
+            rel.source,
+            rel.target,
+            rel.description if rel.description else "",
+        ]
+        if include_relationship_weight:
+            new_record.append(str(rel.weight if rel.weight else ""))
+        for field in attribute_cols:
+            field_value = (
+                str(rel.attributes.get(field))
+                if rel.attributes and rel.attributes.get(field)
+                else ""
+            )
+            new_record.append(field_value)
+        records.append(new_record)
+    return pd.DataFrame(records, columns=cast(Any, header))
diff --git a/func-app/graphrag/query/input/retrieval/text_units.py b/func-app/graphrag/query/input/retrieval/text_units.py
new file mode 100644
index 0000000000..a00dc20a0a
--- /dev/null
+++ b/func-app/graphrag/query/input/retrieval/text_units.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Util functions to retrieve text units from a collection.""" + +from typing import Any, cast + +import pandas as pd + +from graphrag.model import Entity, TextUnit + + +def get_candidate_text_units( + selected_entities: list[Entity], + text_units: list[TextUnit], +) -> pd.DataFrame: + """Get all text units that are associated to selected entities.""" + selected_text_ids = [ + entity.text_unit_ids for entity in selected_entities if entity.text_unit_ids + ] + selected_text_ids = [item for sublist in selected_text_ids for item in sublist] + selected_text_units = [unit for unit in text_units if unit.id in selected_text_ids] + return to_text_unit_dataframe(selected_text_units) + + +def to_text_unit_dataframe(text_units: list[TextUnit]) -> pd.DataFrame: + """Convert a list of text units to a pandas dataframe.""" + if len(text_units) == 0: + return pd.DataFrame() + + # add header + header = ["id", "text"] + attribute_cols = ( + list(text_units[0].attributes.keys()) if text_units[0].attributes else [] + ) + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + + records = [] + for unit in text_units: + new_record = [ + unit.short_id, + unit.text, + *[ + str(unit.attributes.get(field, "")) + if unit.attributes and unit.attributes.get(field) + else "" + for field in attribute_cols + ], + ] + records.append(new_record) + return pd.DataFrame(records, columns=cast(Any, header)) diff --git a/func-app/graphrag/query/llm/__init__.py b/func-app/graphrag/query/llm/__init__.py new file mode 100644 index 0000000000..b8f507b138 --- /dev/null +++ b/func-app/graphrag/query/llm/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Orchestration LLM utilities.""" diff --git a/func-app/graphrag/query/llm/base.py b/func-app/graphrag/query/llm/base.py new file mode 100644 index 0000000000..228150af50 --- /dev/null +++ b/func-app/graphrag/query/llm/base.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Base classes for LLM and Embedding models.""" + +from abc import ABC, abstractmethod +from typing import Any + + +class BaseLLMCallback: + """Base class for LLM callbacks.""" + + def __init__(self): + self.response = [] + + def on_llm_new_token(self, token: str): + """Handle when a new token is generated.""" + self.response.append(token) + + +class BaseLLM(ABC): + """The Base LLM implementation.""" + + @abstractmethod + def generate( + self, + messages: str | list[Any], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + """Generate a response.""" + + @abstractmethod + async def agenerate( + self, + messages: str | list[Any], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + """Generate a response asynchronously.""" + + +class BaseTextEmbedding(ABC): + """The text embedding interface.""" + + @abstractmethod + def embed(self, text: str, **kwargs: Any) -> list[float]: + """Embed a text string.""" + + @abstractmethod + async def aembed(self, text: str, **kwargs: Any) -> list[float]: + """Embed a text string asynchronously.""" diff --git a/func-app/graphrag/query/llm/oai/__init__.py b/func-app/graphrag/query/llm/oai/__init__.py new file mode 100644 index 0000000000..cbb257905e --- /dev/null +++ b/func-app/graphrag/query/llm/oai/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""GraphRAG Orchestration OpenAI Wrappers."""
+
+from .base import BaseOpenAILLM, OpenAILLMImpl, OpenAITextEmbeddingImpl
+from .chat_openai import ChatOpenAI
+from .embedding import OpenAIEmbedding
+from .openai import OpenAI
+from .typing import OPENAI_RETRY_ERROR_TYPES, OpenaiApiType
+
+__all__ = [
+    "OPENAI_RETRY_ERROR_TYPES",
+    "BaseOpenAILLM",
+    "ChatOpenAI",
+    "OpenAI",
+    "OpenAIEmbedding",
+    "OpenAILLMImpl",
+    "OpenAITextEmbeddingImpl",
+    "OpenaiApiType",
+]
diff --git a/func-app/graphrag/query/llm/oai/base.py b/func-app/graphrag/query/llm/oai/base.py
new file mode 100644
index 0000000000..6181c0b2a5
--- /dev/null
+++ b/func-app/graphrag/query/llm/oai/base.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Base classes for LLM and Embedding models."""
+
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
+
+from graphrag.query.llm.base import BaseTextEmbedding
+from graphrag.query.llm.oai.typing import OpenaiApiType
+from graphrag.query.progress import ConsoleStatusReporter, StatusReporter
+
+
+class BaseOpenAILLM(ABC):
+    """The Base OpenAI LLM implementation."""
+
+    _async_client: AsyncOpenAI | AsyncAzureOpenAI
+    _sync_client: OpenAI | AzureOpenAI
+
+    def __init__(self):
+        self._create_openai_client()
+
+    @abstractmethod
+    def _create_openai_client(self):
+        """Create a new synchronous and asynchronous OpenAI client instance."""
+
+    def set_clients(
+        self,
+        sync_client: OpenAI | AzureOpenAI,
+        async_client: AsyncOpenAI | AsyncAzureOpenAI,
+    ):
+        """
+        Set the synchronous and asynchronous clients used for making API requests.
+
+        Args:
+            sync_client (OpenAI | AzureOpenAI): The sync client object.
+            async_client (AsyncOpenAI | AsyncAzureOpenAI): The async client object.
+        """
+        self._sync_client = sync_client
+        self._async_client = async_client
+
+    @property
+    def async_client(self) -> AsyncOpenAI | AsyncAzureOpenAI | None:
+        """
+        Get the asynchronous client used for making API requests.
+
+        Returns
+        -------
+        AsyncOpenAI | AsyncAzureOpenAI: The async client object.
+        """
+        return self._async_client
+
+    @property
+    def sync_client(self) -> OpenAI | AzureOpenAI | None:
+        """
+        Get the synchronous client used for making API requests.
+
+        Returns
+        -------
+        OpenAI | AzureOpenAI: The sync client object.
+        """
+        return self._sync_client
+
+    @async_client.setter
+    def async_client(self, client: AsyncOpenAI | AsyncAzureOpenAI):
+        """
+        Set the asynchronous client used for making API requests.
+
+        Args:
+            client (AsyncOpenAI | AsyncAzureOpenAI): The async client object.
+        """
+        self._async_client = client
+
+    @sync_client.setter
+    def sync_client(self, client: OpenAI | AzureOpenAI):
+        """
+        Set the synchronous client used for making API requests.
+
+        Args:
+            client (OpenAI | AzureOpenAI): The sync client object.
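+
+        Prefer set_clients() when both clients need to be replaced together.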
+ """ + self._sync_client = client + + +class OpenAILLMImpl(BaseOpenAILLM): + """Orchestration OpenAI LLM Implementation.""" + + _reporter: StatusReporter = ConsoleStatusReporter() + + def __init__( + self, + api_key: str | None = None, + azure_ad_token_provider: Callable | None = None, + deployment_name: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + api_type: OpenaiApiType = OpenaiApiType.OpenAI, + organization: str | None = None, + max_retries: int = 10, + request_timeout: float = 180.0, + reporter: StatusReporter | None = None, + ): + self.api_key = api_key + self.azure_ad_token_provider = azure_ad_token_provider + self.deployment_name = deployment_name + self.api_base = api_base + self.api_version = api_version + self.api_type = api_type + self.organization = organization + self.max_retries = max_retries + self.request_timeout = request_timeout + self.reporter = reporter or ConsoleStatusReporter() + + try: + # Create OpenAI sync and async clients + super().__init__() + except Exception as e: + self._reporter.error( + message="Failed to create OpenAI client", + details={self.__class__.__name__: str(e)}, + ) + raise + + def _create_openai_client(self): + """Create a new OpenAI client instance.""" + if self.api_type == OpenaiApiType.AzureOpenAI: + if self.api_base is None: + msg = "api_base is required for Azure OpenAI" + raise ValueError(msg) + + sync_client = AzureOpenAI( + api_key=self.api_key, + azure_ad_token_provider=self.azure_ad_token_provider, + organization=self.organization, + # Azure-Specifics + api_version=self.api_version, + azure_endpoint=self.api_base, + azure_deployment=self.deployment_name, + # Retry Configuration + timeout=self.request_timeout, + max_retries=self.max_retries, + ) + + async_client = AsyncAzureOpenAI( + api_key=self.api_key, + azure_ad_token_provider=self.azure_ad_token_provider, + organization=self.organization, + # Azure-Specifics + api_version=self.api_version, + azure_endpoint=self.api_base, + azure_deployment=self.deployment_name, + # Retry Configuration + timeout=self.request_timeout, + max_retries=self.max_retries, + ) + self.set_clients(sync_client=sync_client, async_client=async_client) + + else: + sync_client = OpenAI( + api_key=self.api_key, + base_url=self.api_base, + organization=self.organization, + # Retry Configuration + timeout=self.request_timeout, + max_retries=self.max_retries, + ) + + async_client = AsyncOpenAI( + api_key=self.api_key, + base_url=self.api_base, + organization=self.organization, + # Retry Configuration + timeout=self.request_timeout, + max_retries=self.max_retries, + ) + self.set_clients(sync_client=sync_client, async_client=async_client) + + +class OpenAITextEmbeddingImpl(BaseTextEmbedding): + """Orchestration OpenAI Text Embedding Implementation.""" + + _reporter: StatusReporter | None = None + + def _create_openai_client(self, api_type: OpenaiApiType): + """Create a new synchronous and asynchronous OpenAI client instance.""" diff --git a/func-app/graphrag/query/llm/oai/chat_openai.py b/func-app/graphrag/query/llm/oai/chat_openai.py new file mode 100644 index 0000000000..92a9755b10 --- /dev/null +++ b/func-app/graphrag/query/llm/oai/chat_openai.py @@ -0,0 +1,206 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Chat-based OpenAI LLM implementation.""" + +from collections.abc import Callable +from typing import Any + +from tenacity import ( + AsyncRetrying, + RetryError, + Retrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) + +from graphrag.query.llm.base import BaseLLM, BaseLLMCallback +from graphrag.query.llm.oai.base import OpenAILLMImpl +from graphrag.query.llm.oai.typing import ( + OPENAI_RETRY_ERROR_TYPES, + OpenaiApiType, +) +from graphrag.query.progress import StatusReporter + +_MODEL_REQUIRED_MSG = "model is required" + + +class ChatOpenAI(BaseLLM, OpenAILLMImpl): + """Wrapper for OpenAI ChatCompletion models.""" + + def __init__( + self, + api_key: str | None = None, + model: str | None = None, + azure_ad_token_provider: Callable | None = None, + deployment_name: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + api_type: OpenaiApiType = OpenaiApiType.OpenAI, + organization: str | None = None, + max_retries: int = 10, + request_timeout: float = 180.0, + retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore + reporter: StatusReporter | None = None, + ): + OpenAILLMImpl.__init__( + self=self, + api_key=api_key, + azure_ad_token_provider=azure_ad_token_provider, + deployment_name=deployment_name, + api_base=api_base, + api_version=api_version, + api_type=api_type, # type: ignore + organization=organization, + max_retries=max_retries, + request_timeout=request_timeout, + reporter=reporter, + ) + self.model = model + self.retry_error_types = retry_error_types + + def generate( + self, + messages: str | list[Any], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + """Generate text.""" + try: + retryer = Retrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), + ) + for attempt in retryer: + with attempt: + return self._generate( + messages=messages, + streaming=streaming, + callbacks=callbacks, + **kwargs, + ) + except RetryError as e: + self._reporter.error( + message="Error at generate()", details={self.__class__.__name__: str(e)} + ) + return "" + else: + # TODO: why not just throw in this case? + return "" + + async def agenerate( + self, + messages: str | list[Any], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + """Generate text asynchronously.""" + try: + retryer = AsyncRetrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), # type: ignore + ) + async for attempt in retryer: + with attempt: + return await self._agenerate( + messages=messages, + streaming=streaming, + callbacks=callbacks, + **kwargs, + ) + except RetryError as e: + self._reporter.error(f"Error at agenerate(): {e}") + return "" + else: + # TODO: why not just throw in this case? 
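+            # effectively unreachable: the retry loop above either returns a response or raises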
+            return ""
+
+    def _generate(
+        self,
+        messages: str | list[Any],
+        streaming: bool = True,
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> str:
+        model = self.model
+        if not model:
+            raise ValueError(_MODEL_REQUIRED_MSG)
+        response = self.sync_client.chat.completions.create(  # type: ignore
+            model=model,
+            messages=messages,  # type: ignore
+            stream=streaming,
+            **kwargs,
+        )  # type: ignore
+        if streaming:
+            full_response = ""
+            while True:
+                try:
+                    chunk = response.__next__()  # type: ignore
+                    if not chunk or not chunk.choices:
+                        continue
+
+                    delta = (
+                        chunk.choices[0].delta.content
+                        if chunk.choices[0].delta and chunk.choices[0].delta.content
+                        else ""
+                    )  # type: ignore
+
+                    full_response += delta
+                    if callbacks:
+                        for callback in callbacks:
+                            callback.on_llm_new_token(delta)
+                    if chunk.choices[0].finish_reason == "stop":  # type: ignore
+                        break
+                except StopIteration:
+                    break
+            return full_response
+        return response.choices[0].message.content or ""  # type: ignore
+
+    async def _agenerate(
+        self,
+        messages: str | list[Any],
+        streaming: bool = True,
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> str:
+        model = self.model
+        if not model:
+            raise ValueError(_MODEL_REQUIRED_MSG)
+        response = await self.async_client.chat.completions.create(  # type: ignore
+            model=model,
+            messages=messages,  # type: ignore
+            stream=streaming,
+            **kwargs,
+        )
+        if streaming:
+            full_response = ""
+            while True:
+                try:
+                    chunk = await response.__anext__()  # type: ignore
+                    if not chunk or not chunk.choices:
+                        continue
+
+                    delta = (
+                        chunk.choices[0].delta.content
+                        if chunk.choices[0].delta and chunk.choices[0].delta.content
+                        else ""
+                    )  # type: ignore
+
+                    full_response += delta
+                    if callbacks:
+                        for callback in callbacks:
+                            callback.on_llm_new_token(delta)
+                    if chunk.choices[0].finish_reason == "stop":  # type: ignore
+                        break
+                except StopAsyncIteration:  # __anext__ raises StopAsyncIteration, not StopIteration
+                    break
+            return full_response
+
+        return response.choices[0].message.content or ""  # type: ignore
diff --git a/func-app/graphrag/query/llm/oai/embedding.py b/func-app/graphrag/query/llm/oai/embedding.py
new file mode 100644
index 0000000000..f40372dbce
--- /dev/null
+++ b/func-app/graphrag/query/llm/oai/embedding.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""OpenAI Embedding model implementation.""" + +import asyncio +from collections.abc import Callable +from typing import Any + +import numpy as np +import tiktoken +from tenacity import ( + AsyncRetrying, + RetryError, + Retrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) + +from graphrag.query.llm.base import BaseTextEmbedding +from graphrag.query.llm.oai.base import OpenAILLMImpl +from graphrag.query.llm.oai.typing import ( + OPENAI_RETRY_ERROR_TYPES, + OpenaiApiType, +) +from graphrag.query.llm.text_utils import chunk_text +from graphrag.query.progress import StatusReporter + + +class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl): + """Wrapper for OpenAI Embedding models.""" + + def __init__( + self, + api_key: str | None = None, + azure_ad_token_provider: Callable | None = None, + model: str = "text-embedding-3-small", + deployment_name: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + api_type: OpenaiApiType = OpenaiApiType.OpenAI, + organization: str | None = None, + encoding_name: str = "cl100k_base", + max_tokens: int = 8191, + max_retries: int = 10, + request_timeout: float = 180.0, + retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore + reporter: StatusReporter | None = None, + ): + OpenAILLMImpl.__init__( + self=self, + api_key=api_key, + azure_ad_token_provider=azure_ad_token_provider, + deployment_name=deployment_name, + api_base=api_base, + api_version=api_version, + api_type=api_type, # type: ignore + organization=organization, + max_retries=max_retries, + request_timeout=request_timeout, + reporter=reporter, + ) + + self.model = model + self.encoding_name = encoding_name + self.max_tokens = max_tokens + self.token_encoder = tiktoken.get_encoding(self.encoding_name) + self.retry_error_types = retry_error_types + + def embed(self, text: str, **kwargs: Any) -> list[float]: + """ + Embed text using OpenAI Embedding's sync function. + + For text longer than max_tokens, chunk texts into max_tokens, embed each chunk, then combine using weighted average. + Please refer to: https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb + """ + token_chunks = chunk_text( + text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens + ) + chunk_embeddings = [] + chunk_lens = [] + for chunk in token_chunks: + try: + embedding, chunk_len = self._embed_with_retry(chunk, **kwargs) + chunk_embeddings.append(embedding) + chunk_lens.append(chunk_len) + # TODO: catch a more specific exception + except Exception as e: # noqa BLE001 + self._reporter.error( + message="Error embedding chunk", + details={self.__class__.__name__: str(e)}, + ) + + continue + chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens) + chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings) + return chunk_embeddings.tolist() + + async def aembed(self, text: str, **kwargs: Any) -> list[float]: + """ + Embed text using OpenAI Embedding's async function. + + For text longer than max_tokens, chunk texts into max_tokens, embed each chunk, then combine using weighted average. 
+ """ + token_chunks = chunk_text( + text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens + ) + chunk_embeddings = [] + chunk_lens = [] + embedding_results = await asyncio.gather(*[ + self._aembed_with_retry(chunk, **kwargs) for chunk in token_chunks + ]) + embedding_results = [result for result in embedding_results if result[0]] + chunk_embeddings = [result[0] for result in embedding_results] + chunk_lens = [result[1] for result in embedding_results] + chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens) # type: ignore + chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings) + return chunk_embeddings.tolist() + + def _embed_with_retry( + self, text: str | tuple, **kwargs: Any + ) -> tuple[list[float], int]: + try: + retryer = Retrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), + ) + for attempt in retryer: + with attempt: + embedding = ( + self.sync_client.embeddings.create( # type: ignore + input=text, + model=self.model, + **kwargs, # type: ignore + ) + .data[0] + .embedding + or [] + ) + return (embedding, len(text)) + except RetryError as e: + self._reporter.error( + message="Error at embed_with_retry()", + details={self.__class__.__name__: str(e)}, + ) + return ([], 0) + else: + # TODO: why not just throw in this case? + return ([], 0) + + async def _aembed_with_retry( + self, text: str | tuple, **kwargs: Any + ) -> tuple[list[float], int]: + try: + retryer = AsyncRetrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), + ) + async for attempt in retryer: + with attempt: + embedding = ( + await self.async_client.embeddings.create( # type: ignore + input=text, + model=self.model, + **kwargs, # type: ignore + ) + ).data[0].embedding or [] + return (embedding, len(text)) + except RetryError as e: + self._reporter.error( + message="Error at embed_with_retry()", + details={self.__class__.__name__: str(e)}, + ) + return ([], 0) + else: + # TODO: why not just throw in this case? + return ([], 0) diff --git a/func-app/graphrag/query/llm/oai/openai.py b/func-app/graphrag/query/llm/oai/openai.py new file mode 100644 index 0000000000..76bb5fe52c --- /dev/null +++ b/func-app/graphrag/query/llm/oai/openai.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""OpenAI Wrappers for Orchestration."""
+
+import logging
+from typing import Any
+
+from tenacity import (
+    AsyncRetrying,
+    RetryError,
+    Retrying,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential_jitter,
+)
+
+from graphrag.query.llm.base import BaseLLMCallback
+from graphrag.query.llm.oai.base import OpenAILLMImpl
+from graphrag.query.llm.oai.typing import (
+    OPENAI_RETRY_ERROR_TYPES,
+    OpenaiApiType,
+)
+
+log = logging.getLogger(__name__)
+
+
+class OpenAI(OpenAILLMImpl):
+    """Wrapper for OpenAI Completion models."""
+
+    def __init__(
+        self,
+        api_key: str,
+        model: str,
+        deployment_name: str | None = None,
+        api_base: str | None = None,
+        api_version: str | None = None,
+        api_type: OpenaiApiType = OpenaiApiType.OpenAI,
+        organization: str | None = None,
+        max_retries: int = 10,
+        retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES,  # type: ignore
+    ):
+        self.api_key = api_key
+        self.model = model
+        self.deployment_name = deployment_name
+        self.api_base = api_base
+        self.api_version = api_version
+        self.api_type = api_type
+        self.organization = organization
+        self.max_retries = max_retries
+        self.retry_error_types = retry_error_types
+
+    def generate(
+        self,
+        messages: str | list[str],
+        streaming: bool = True,
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> str:
+        """Generate text."""
+        try:
+            retryer = Retrying(
+                stop=stop_after_attempt(self.max_retries),
+                wait=wait_exponential_jitter(max=10),
+                reraise=True,
+                retry=retry_if_exception_type(self.retry_error_types),
+            )
+            for attempt in retryer:
+                with attempt:
+                    return self._generate(
+                        messages=messages,
+                        streaming=streaming,
+                        callbacks=callbacks,
+                        **kwargs,
+                    )
+        except RetryError:
+            log.exception("RetryError at generate()")
+            return ""
+        else:
+            # TODO: why not just throw in this case?
+            return ""
+
+    async def agenerate(
+        self,
+        messages: str | list[str],
+        streaming: bool = True,
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> str:
+        """Generate Text Asynchronously."""
+        try:
+            retryer = AsyncRetrying(
+                stop=stop_after_attempt(self.max_retries),
+                wait=wait_exponential_jitter(max=10),
+                reraise=True,
+                retry=retry_if_exception_type(self.retry_error_types),
+            )
+            async for attempt in retryer:
+                with attempt:
+                    return await self._agenerate(
+                        messages=messages,
+                        streaming=streaming,
+                        callbacks=callbacks,
+                        **kwargs,
+                    )
+        except RetryError:
+            log.exception("Error at agenerate()")
+            return ""
+        else:
+            # TODO: why not just throw in this case?
+ return "" + + def _generate( + self, + messages: str | list[str], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + response = self.sync_client.chat.completions.create( # type: ignore + model=self.model, + messages=messages, # type: ignore + stream=streaming, + **kwargs, + ) # type: ignore + if streaming: + full_response = "" + while True: + try: + chunk = response.__next__() # type: ignore + if not chunk or not chunk.choices: + continue + + delta = ( + chunk.choices[0].delta.content + if chunk.choices[0].delta and chunk.choices[0].delta.content + else "" + ) # type: ignore + + full_response += delta + if callbacks: + for callback in callbacks: + callback.on_llm_new_token(delta) + if chunk.choices[0].finish_reason == "stop": # type: ignore + break + except StopIteration: + break + return full_response + return response.choices[0].message.content or "" # type: ignore + + async def _agenerate( + self, + messages: str | list[str], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + response = await self.async_client.chat.completions.create( # type: ignore + model=self.model, + messages=messages, # type: ignore + stream=streaming, + **kwargs, + ) + if streaming: + full_response = "" + while True: + try: + chunk = await response.__anext__() # type: ignore + if not chunk or not chunk.choices: + continue + + delta = ( + chunk.choices[0].delta.content + if chunk.choices[0].delta and chunk.choices[0].delta.content + else "" + ) # type: ignore + + full_response += delta + if callbacks: + for callback in callbacks: + callback.on_llm_new_token(delta) + if chunk.choices[0].finish_reason == "stop": # type: ignore + break + except StopIteration: + break + return full_response + return response.choices[0].message.content or "" # type: ignore diff --git a/func-app/graphrag/query/llm/oai/typing.py b/func-app/graphrag/query/llm/oai/typing.py new file mode 100644 index 0000000000..399a82f699 --- /dev/null +++ b/func-app/graphrag/query/llm/oai/typing.py @@ -0,0 +1,23 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""OpenAI wrapper options.""" + +from enum import Enum +from typing import Any, cast + +import openai + +OPENAI_RETRY_ERROR_TYPES = ( + # TODO: update these when we update to OpenAI 1+ library + cast(Any, openai).RateLimitError, + cast(Any, openai).APIConnectionError, + # TODO: replace with comparable OpenAI 1+ error +) + + +class OpenaiApiType(str, Enum): + """The OpenAI Flavor.""" + + OpenAI = "openai" + AzureOpenAI = "azure" diff --git a/func-app/graphrag/query/llm/text_utils.py b/func-app/graphrag/query/llm/text_utils.py new file mode 100644 index 0000000000..d60e630488 --- /dev/null +++ b/func-app/graphrag/query/llm/text_utils.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Text Utilities for LLM.""" + +from collections.abc import Iterator +from itertools import islice + +import tiktoken + + +def num_tokens(text: str, token_encoder: tiktoken.Encoding | None = None) -> int: + """Return the number of tokens in the given text.""" + if token_encoder is None: + token_encoder = tiktoken.get_encoding("cl100k_base") + return len(token_encoder.encode(text)) # type: ignore + + +def batched(iterable: Iterator, n: int): + """ + Batch data into tuples of length n. The last batch may be shorter. 
+
+    Taken from Python's cookbook: https://docs.python.org/3/library/itertools.html#itertools.batched
+    """
+    # batched('ABCDEFG', 3) --> ABC DEF G
+    if n < 1:
+        value_error = "n must be at least one"
+        raise ValueError(value_error)
+    it = iter(iterable)
+    while batch := tuple(islice(it, n)):
+        yield batch
+
+
+def chunk_text(
+    text: str, max_tokens: int, token_encoder: tiktoken.Encoding | None = None
+):
+    """Chunk text by token length."""
+    if token_encoder is None:
+        token_encoder = tiktoken.get_encoding("cl100k_base")
+    tokens = token_encoder.encode(text)  # type: ignore
+    chunk_iterator = batched(iter(tokens), max_tokens)
+    yield from chunk_iterator
diff --git a/func-app/graphrag/query/progress.py b/func-app/graphrag/query/progress.py
new file mode 100644
index 0000000000..ad5bcee734
--- /dev/null
+++ b/func-app/graphrag/query/progress.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Status Reporter for orchestration."""
+
+from abc import ABCMeta, abstractmethod
+from typing import Any
+
+
+class StatusReporter(metaclass=ABCMeta):
+    """Provides a way to report status updates from the pipeline."""
+
+    @abstractmethod
+    def error(self, message: str, details: dict[str, Any] | None = None):
+        """Report an error."""
+
+    @abstractmethod
+    def warning(self, message: str, details: dict[str, Any] | None = None):
+        """Report a warning."""
+
+    @abstractmethod
+    def log(self, message: str, details: dict[str, Any] | None = None):
+        """Report a log."""
+
+
+class ConsoleStatusReporter(StatusReporter):
+    """A reporter that writes to a console."""
+
+    def error(self, message: str, details: dict[str, Any] | None = None):
+        """Report an error."""
+        print(message, details)  # noqa T201
+
+    def warning(self, message: str, details: dict[str, Any] | None = None):
+        """Report a warning."""
+        _print_warning(message)
+
+    def log(self, message: str, details: dict[str, Any] | None = None):
+        """Report a log."""
+        print(message, details)  # noqa T201
+
+
+def _print_warning(message):
+    print(f"\033[93m {message}\033[00m")  # noqa T201
diff --git a/func-app/graphrag/query/question_gen/__init__.py b/func-app/graphrag/query/question_gen/__init__.py
new file mode 100644
index 0000000000..d7329277c2
--- /dev/null
+++ b/func-app/graphrag/query/question_gen/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Question Generation Module."""
diff --git a/func-app/graphrag/query/question_gen/base.py b/func-app/graphrag/query/question_gen/base.py
new file mode 100644
index 0000000000..959b63d791
--- /dev/null
+++ b/func-app/graphrag/query/question_gen/base.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2024 Microsoft Corporation.
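
A quick check of the text_utils helpers above: num_tokens counts tokens under a tiktoken encoding, and chunk_text yields tuples of token ids no longer than max_tokens, which can be decoded back to text. Exact token counts depend on the tiktoken version; the sentence is a toy input:

    import tiktoken

    from graphrag.query.llm.text_utils import chunk_text, num_tokens

    enc = tiktoken.get_encoding("cl100k_base")
    text = "GraphRAG builds a knowledge graph over your documents. " * 40

    print(num_tokens(text, enc))  # total token count of the repeated sentence

    # Each chunk is a tuple of at most 100 token ids; decode to recover the span.
    for chunk in chunk_text(text, max_tokens=100, token_encoder=enc):
        print(len(chunk), enc.decode(list(chunk))[:40], "...")
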
+# Licensed under the MIT License
+
+"""Base classes for generating questions based on previously asked questions and most recent context data."""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any
+
+import tiktoken
+
+from graphrag.query.context_builder.builders import (
+    GlobalContextBuilder,
+    LocalContextBuilder,
+)
+from graphrag.query.llm.base import BaseLLM
+
+
+@dataclass
+class QuestionResult:
+    """A Structured Question Result."""
+
+    response: list[str]
+    context_data: str | dict[str, Any]
+    completion_time: float
+    llm_calls: int
+    prompt_tokens: int
+
+
+class BaseQuestionGen(ABC):
+    """The Base Question Gen implementation."""
+
+    def __init__(
+        self,
+        llm: BaseLLM,
+        context_builder: GlobalContextBuilder | LocalContextBuilder,
+        token_encoder: tiktoken.Encoding | None = None,
+        llm_params: dict[str, Any] | None = None,
+        context_builder_params: dict[str, Any] | None = None,
+    ):
+        self.llm = llm
+        self.context_builder = context_builder
+        self.token_encoder = token_encoder
+        self.llm_params = llm_params or {}
+        self.context_builder_params = context_builder_params or {}
+
+    @abstractmethod
+    def generate(
+        self,
+        question_history: list[str],
+        context_data: str | None,
+        question_count: int,
+        **kwargs,
+    ) -> QuestionResult:
+        """Generate questions."""
+
+    @abstractmethod
+    async def agenerate(
+        self,
+        question_history: list[str],
+        context_data: str | None,
+        question_count: int,
+        **kwargs,
+    ) -> QuestionResult:
+        """Generate questions asynchronously."""
diff --git a/func-app/graphrag/query/question_gen/local_gen.py b/func-app/graphrag/query/question_gen/local_gen.py
new file mode 100644
index 0000000000..ca703a66e3
--- /dev/null
+++ b/func-app/graphrag/query/question_gen/local_gen.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Local question generation."""
+
+import logging
+import time
+from typing import Any
+
+import tiktoken
+
+from graphrag.query.context_builder.builders import LocalContextBuilder
+from graphrag.query.context_builder.conversation_history import (
+    ConversationHistory,
+)
+from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
+from graphrag.query.llm.text_utils import num_tokens
+from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult
+from graphrag.query.question_gen.system_prompt import QUESTION_SYSTEM_PROMPT
+
+log = logging.getLogger(__name__)
+
+
+class LocalQuestionGen(BaseQuestionGen):
+    """Question generation for local search mode."""
+
+    def __init__(
+        self,
+        llm: BaseLLM,
+        context_builder: LocalContextBuilder,
+        token_encoder: tiktoken.Encoding | None = None,
+        system_prompt: str = QUESTION_SYSTEM_PROMPT,
+        callbacks: list[BaseLLMCallback] | None = None,
+        llm_params: dict[str, Any] | None = None,
+        context_builder_params: dict[str, Any] | None = None,
+    ):
+        super().__init__(
+            llm=llm,
+            context_builder=context_builder,
+            token_encoder=token_encoder,
+            llm_params=llm_params,
+            context_builder_params=context_builder_params,
+        )
+        self.system_prompt = system_prompt
+        self.callbacks = callbacks
+
+    async def agenerate(
+        self,
+        question_history: list[str],
+        context_data: str | None,
+        question_count: int,
+        **kwargs,
+    ) -> QuestionResult:
+        """
+        Generate a question based on the question history and context data.
+ + If context data is not provided, it will be generated by the local context builder + """ + start_time = time.time() + + if len(question_history) == 0: + question_text = "" + conversation_history = None + else: + # construct current query and conversation history + question_text = question_history[-1] + history = [ + {"role": "user", "content": query} for query in question_history[:-1] + ] + conversation_history = ConversationHistory.from_list(history) + + if context_data is None: + # generate context data based on the question history + context_data, context_records = self.context_builder.build_context( + query=question_text, + conversation_history=conversation_history, + **kwargs, + **self.context_builder_params, + ) # type: ignore + else: + context_records = {"context_data": context_data} + log.info("GENERATE QUESTION: %s. LAST QUESTION: %s", start_time, question_text) + system_prompt = "" + try: + system_prompt = self.system_prompt.format( + context_data=context_data, question_count=question_count + ) + question_messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question_text}, + ] + + response = await self.llm.agenerate( + messages=question_messages, + streaming=True, + callbacks=self.callbacks, + **self.llm_params, + ) + + return QuestionResult( + response=response.split("\n"), + context_data={ + "question_context": question_text, + **context_records, + }, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(system_prompt, self.token_encoder), + ) + + except Exception: + log.exception("Exception in generating question") + return QuestionResult( + response=[], + context_data=context_records, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(system_prompt, self.token_encoder), + ) + + def generate( + self, + question_history: list[str], + context_data: str | None, + question_count: int, + **kwargs, + ) -> QuestionResult: + """ + Generate a question based on the question history and context data. + + If context data is not provided, it will be generated by the local context builder + """ + start_time = time.time() + if len(question_history) == 0: + question_text = "" + conversation_history = None + else: + # construct current query and conversation history + question_text = question_history[-1] + history = [ + {"role": "user", "content": query} for query in question_history[:-1] + ] + conversation_history = ConversationHistory.from_list(history) + + if context_data is None: + # generate context data based on the question history + context_data, context_records = self.context_builder.build_context( + query=question_text, + conversation_history=conversation_history, + **kwargs, + **self.context_builder_params, + ) # type: ignore + else: + context_records = {"context_data": context_data} + log.info( + "GENERATE QUESTION: %s. 
LAST QUESTION: %s", start_time, question_text
+        )
+        system_prompt = ""
+        try:
+            system_prompt = self.system_prompt.format(
+                context_data=context_data, question_count=question_count
+            )
+            question_messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": question_text},
+            ]
+
+            response = self.llm.generate(
+                messages=question_messages,
+                streaming=True,
+                callbacks=self.callbacks,
+                **self.llm_params,
+            )
+
+            return QuestionResult(
+                response=response.split("\n"),
+                context_data={
+                    "question_context": question_text,
+                    **context_records,
+                },
+                completion_time=time.time() - start_time,
+                llm_calls=1,
+                prompt_tokens=num_tokens(system_prompt, self.token_encoder),
+            )
+
+        except Exception:
+            log.exception("Exception in generating questions")
+            return QuestionResult(
+                response=[],
+                context_data=context_records,
+                completion_time=time.time() - start_time,
+                llm_calls=1,
+                prompt_tokens=num_tokens(system_prompt, self.token_encoder),
+            )
diff --git a/func-app/graphrag/query/question_gen/system_prompt.py b/func-app/graphrag/query/question_gen/system_prompt.py
new file mode 100644
index 0000000000..904ede2435
--- /dev/null
+++ b/func-app/graphrag/query/question_gen/system_prompt.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Question Generation system prompts."""
+
+QUESTION_SYSTEM_PROMPT = """
+---Role---
+
+You are a helpful assistant generating a bulleted list of {question_count} questions about data in the tables provided.
+
+
+---Data tables---
+
+{context_data}
+
+
+---Goal---
+
+Given a series of example questions provided by the user, generate a bulleted list of {question_count} candidates for the next question. Use - marks as bullet points.
+
+These candidate questions should represent the most important or urgent information content or themes in the data tables.
+
+The candidate questions should be answerable using the data tables provided, but should not mention any specific data fields or data tables in the question text.
+
+If the user's questions reference several named entities, then each candidate question should reference all named entities.
+
+---Example questions---
+"""
diff --git a/func-app/graphrag/query/structured_search/__init__.py b/func-app/graphrag/query/structured_search/__init__.py
new file mode 100644
index 0000000000..b41baaf340
--- /dev/null
+++ b/func-app/graphrag/query/structured_search/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Structured Search package."""
diff --git a/func-app/graphrag/query/structured_search/base.py b/func-app/graphrag/query/structured_search/base.py
new file mode 100644
index 0000000000..6dd02485f8
--- /dev/null
+++ b/func-app/graphrag/query/structured_search/base.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2024 Microsoft Corporation.
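
QUESTION_SYSTEM_PROMPT above is a plain str.format template with two fields. A sketch of how LocalQuestionGen instantiates it before calling the LLM; the context table and user question are toy values:

    from graphrag.query.question_gen.system_prompt import QUESTION_SYSTEM_PROMPT

    system_prompt = QUESTION_SYSTEM_PROMPT.format(
        context_data="|id|title|occurrence|\n|1|Project Alpha|0.8|",  # toy table
        question_count=3,
    )
    question_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "What is Project Alpha?"},  # last question in history
    ]
    # question_messages is then passed to llm.generate(...) / llm.agenerate(...)
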
+# Licensed under the MIT License + +"""Base classes for search algos.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +import pandas as pd +import tiktoken + +from graphrag.query.context_builder.builders import ( + GlobalContextBuilder, + LocalContextBuilder, +) +from graphrag.query.context_builder.conversation_history import ( + ConversationHistory, +) +from graphrag.query.llm.base import BaseLLM + + +@dataclass +class SearchResult: + """A Structured Search Result.""" + + response: str | dict[str, Any] | list[dict[str, Any]] + context_data: str | list[pd.DataFrame] | dict[str, pd.DataFrame] + # actual text strings that are in the context window, built from context_data + context_text: str | list[str] | dict[str, str] + completion_time: float + llm_calls: int + prompt_tokens: int + + +class BaseSearch(ABC): + """The Base Search implementation.""" + + def __init__( + self, + llm: BaseLLM, + context_builder: GlobalContextBuilder | LocalContextBuilder, + token_encoder: tiktoken.Encoding | None = None, + llm_params: dict[str, Any] | None = None, + context_builder_params: dict[str, Any] | None = None, + ): + self.llm = llm + self.context_builder = context_builder + self.token_encoder = token_encoder + self.llm_params = llm_params or {} + self.context_builder_params = context_builder_params or {} + + @abstractmethod + def search( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> SearchResult: + """Search for the given query.""" + + @abstractmethod + async def asearch( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> SearchResult: + """Search for the given query asynchronously.""" diff --git a/func-app/graphrag/query/structured_search/global_search/__init__.py b/func-app/graphrag/query/structured_search/global_search/__init__.py new file mode 100644 index 0000000000..ba73b60900 --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""GlobalSearch module.""" diff --git a/func-app/graphrag/query/structured_search/global_search/callbacks.py b/func-app/graphrag/query/structured_search/global_search/callbacks.py new file mode 100644 index 0000000000..f48bb79b82 --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/callbacks.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024 Microsoft Corporation. 
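
The BaseSearch contract in base.py above is small: subclasses provide search/asearch and return a SearchResult. A toy subclass, written only to illustrate that contract (EchoSearch is invented for this sketch and does no retrieval):

    import time

    from graphrag.query.context_builder.conversation_history import ConversationHistory
    from graphrag.query.structured_search.base import BaseSearch, SearchResult

    class EchoSearch(BaseSearch):
        """Toy search: echoes the query back, with empty context and zero LLM cost."""

        def search(
            self,
            query: str,
            conversation_history: ConversationHistory | None = None,
            **kwargs,
        ) -> SearchResult:
            start = time.time()
            return SearchResult(
                response=f"echo: {query}",
                context_data="",
                context_text="",
                completion_time=time.time() - start,
                llm_calls=0,
                prompt_tokens=0,
            )

        async def asearch(
            self,
            query: str,
            conversation_history: ConversationHistory | None = None,
            **kwargs,
        ) -> SearchResult:
            return self.search(query, conversation_history, **kwargs)
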
+# Licensed under the MIT License + +"""GlobalSearch LLM Callbacks.""" + +from graphrag.query.llm.base import BaseLLMCallback +from graphrag.query.structured_search.base import SearchResult + + +class GlobalSearchLLMCallback(BaseLLMCallback): + """GlobalSearch LLM Callbacks.""" + + def __init__(self): + super().__init__() + self.map_response_contexts = [] + self.map_response_outputs = [] + + def on_map_response_start(self, map_response_contexts: list[str]): + """Handle the start of map response.""" + self.map_response_contexts = map_response_contexts + + def on_map_response_end(self, map_response_outputs: list[SearchResult]): + """Handle the end of map response.""" + self.map_response_outputs = map_response_outputs diff --git a/func-app/graphrag/query/structured_search/global_search/community_context.py b/func-app/graphrag/query/structured_search/global_search/community_context.py new file mode 100644 index 0000000000..d63320c85b --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/community_context.py @@ -0,0 +1,99 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Contains algorithms to build context data for global search prompt.""" + +from typing import Any + +import pandas as pd +import tiktoken + +from graphrag.model import CommunityReport, Entity +from graphrag.query.context_builder.community_context import ( + build_community_context, +) +from graphrag.query.context_builder.conversation_history import ( + ConversationHistory, +) +from graphrag.query.structured_search.base import GlobalContextBuilder + + +class GlobalCommunityContext(GlobalContextBuilder): + """GlobalSearch community context builder.""" + + def __init__( + self, + community_reports: list[CommunityReport], + entities: list[Entity] | None = None, + token_encoder: tiktoken.Encoding | None = None, + random_state: int = 86, + ): + self.community_reports = community_reports + self.entities = entities + self.token_encoder = token_encoder + self.random_state = random_state + + def build_context( + self, + conversation_history: ConversationHistory | None = None, + use_community_summary: bool = True, + column_delimiter: str = "|", + shuffle_data: bool = True, + include_community_rank: bool = False, + min_community_rank: int = 0, + community_rank_name: str = "rank", + include_community_weight: bool = True, + community_weight_name: str = "occurrence", + normalize_community_weight: bool = True, + max_tokens: int = 8000, + context_name: str = "Reports", + conversation_history_user_turns_only: bool = True, + conversation_history_max_turns: int | None = 5, + **kwargs: Any, + ) -> tuple[str | list[str], dict[str, pd.DataFrame]]: + """Prepare batches of community report data table as context data for global search.""" + conversation_history_context = "" + final_context_data = {} + if conversation_history: + # build conversation history context + ( + conversation_history_context, + conversation_history_context_data, + ) = conversation_history.build_context( + include_user_turns_only=conversation_history_user_turns_only, + max_qa_turns=conversation_history_max_turns, + column_delimiter=column_delimiter, + max_tokens=max_tokens, + recency_bias=False, + ) + if conversation_history_context != "": + final_context_data = conversation_history_context_data + + community_context, community_context_data = build_community_context( + community_reports=self.community_reports, + entities=self.entities, + token_encoder=self.token_encoder, + use_community_summary=use_community_summary, + 
column_delimiter=column_delimiter, + shuffle_data=shuffle_data, + include_community_rank=include_community_rank, + min_community_rank=min_community_rank, + community_rank_name=community_rank_name, + include_community_weight=include_community_weight, + community_weight_name=community_weight_name, + normalize_community_weight=normalize_community_weight, + max_tokens=max_tokens, + single_batch=False, + context_name=context_name, + random_state=self.random_state, + ) + if isinstance(community_context, list): + final_context = [ + f"{conversation_history_context}\n\n{context}" + for context in community_context + ] + else: + final_context = f"{conversation_history_context}\n\n{community_context}" + + final_context_data.update(community_context_data) + return (final_context, final_context_data) diff --git a/func-app/graphrag/query/structured_search/global_search/map_system_prompt.py b/func-app/graphrag/query/structured_search/global_search/map_system_prompt.py new file mode 100644 index 0000000000..db1a649df3 --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/map_system_prompt.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""System prompts for global search.""" + +MAP_SYSTEM_PROMPT = """ +---Role--- + +You are a helpful assistant responding to questions about data in the tables provided. + + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. + +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response should be JSON formatted as follows: +{{ + "points": [ + {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, + {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + ] +}} + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +Points supported by data should list the relevant reports as references as follows: +"This is an example sentence supported by data references [Data: Reports (report ids)]" + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. + +Do not include information where the supporting evidence for it is not provided. + + +---Data tables--- + +{context_data} + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. 
+ +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +Points supported by data should list the relevant reports as references as follows: +"This is an example sentence supported by data references [Data: Reports (report ids)]" + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. + +Do not include information where the supporting evidence for it is not provided. + +The response should be JSON formatted as follows: +{{ + "points": [ + {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, + {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + ] +}} +""" diff --git a/func-app/graphrag/query/structured_search/global_search/reduce_system_prompt.py b/func-app/graphrag/query/structured_search/global_search/reduce_system_prompt.py new file mode 100644 index 0000000000..701717817c --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/reduce_system_prompt.py @@ -0,0 +1,88 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Global Search system prompts.""" + +REDUCE_SYSTEM_PROMPT = """ +---Role--- + +You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts. + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. + +**Do not list more than 5 record ids in a single reference**. 
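
The map prompt above pins the model to a small JSON schema. A sketch of what a conforming map response looks like and how it reduces to scored key points, mirroring the parse_search_response logic further down; the sample payload is fabricated:

    import json

    sample_map_response = """
    {
      "points": [
        {"description": "Company Y faces allegations [Data: Reports (2, 7)]", "score": 85},
        {"description": "Person X is CEO of Company X [Data: Reports (1, 3)]", "score": 60}
      ]
    }
    """

    points = json.loads(sample_map_response)["points"]
    key_points = [
        {"answer": p["description"], "score": int(p["score"])}
        for p in points
        if "description" in p and "score" in p
    ]
    print(key_points[0]["score"])  # 85
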
Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + + +---Analyst Reports--- + +{report_data} + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. +""" + +NO_DATA_ANSWER = ( + "I am sorry but I am unable to answer this question given the provided data." +) + +GENERAL_KNOWLEDGE_INSTRUCTION = """ +The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example: +"This is an example sentence supported by real-world knowledge [LLM: verify]." +""" diff --git a/func-app/graphrag/query/structured_search/global_search/search.py b/func-app/graphrag/query/structured_search/global_search/search.py new file mode 100644 index 0000000000..12dc45fe7a --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/search.py @@ -0,0 +1,359 @@ +# Copyright (c) 2024 Microsoft Corporation. 
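
GlobalSearch, implemented in search.py below, fans the map calls out with asyncio.gather while capping in-flight concurrency through a semaphore. The pattern in isolation, with a stand-in coroutine (map_one and the sleep are placeholders for one map LLM call):

    import asyncio

    semaphore = asyncio.Semaphore(4)  # at most 4 in-flight calls, like concurrent_coroutines

    async def map_one(batch_id: int) -> str:
        async with semaphore:
            await asyncio.sleep(0.1)  # stand-in for one map LLM call
            return f"answer for batch {batch_id}"

    async def main() -> None:
        # launch all batches at once; the semaphore throttles actual concurrency
        results = await asyncio.gather(*[map_one(i) for i in range(10)])
        print(len(results))  # 10

    asyncio.run(main())
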
+# Licensed under the MIT License
+
+"""The GlobalSearch Implementation."""
+
+import asyncio
+import json
+import logging
+import time
+from dataclasses import dataclass
+from typing import Any
+
+import pandas as pd
+import tiktoken
+
+from graphrag.llm.openai.utils import try_parse_json_object
+from graphrag.query.context_builder.builders import GlobalContextBuilder
+from graphrag.query.context_builder.conversation_history import (
+    ConversationHistory,
+)
+from graphrag.query.llm.base import BaseLLM
+from graphrag.query.llm.text_utils import num_tokens
+from graphrag.query.structured_search.base import BaseSearch, SearchResult
+from graphrag.query.structured_search.global_search.callbacks import (
+    GlobalSearchLLMCallback,
+)
+from graphrag.query.structured_search.global_search.map_system_prompt import (
+    MAP_SYSTEM_PROMPT,
+)
+from graphrag.query.structured_search.global_search.reduce_system_prompt import (
+    GENERAL_KNOWLEDGE_INSTRUCTION,
+    NO_DATA_ANSWER,
+    REDUCE_SYSTEM_PROMPT,
+)
+
+DEFAULT_MAP_LLM_PARAMS = {
+    "max_tokens": 1000,
+    "temperature": 0.0,
+}
+
+DEFAULT_REDUCE_LLM_PARAMS = {
+    "max_tokens": 2000,
+    "temperature": 0.0,
+}
+
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class GlobalSearchResult(SearchResult):
+    """A GlobalSearch result."""
+
+    map_responses: list[SearchResult]
+    reduce_context_data: str | list[pd.DataFrame] | dict[str, pd.DataFrame]
+    reduce_context_text: str | list[str] | dict[str, str]
+
+
+class GlobalSearch(BaseSearch):
+    """Search orchestration for global search mode."""
+
+    def __init__(
+        self,
+        llm: BaseLLM,
+        context_builder: GlobalContextBuilder,
+        token_encoder: tiktoken.Encoding | None = None,
+        map_system_prompt: str = MAP_SYSTEM_PROMPT,
+        reduce_system_prompt: str = REDUCE_SYSTEM_PROMPT,
+        response_type: str = "multiple paragraphs",
+        allow_general_knowledge: bool = False,
+        general_knowledge_inclusion_prompt: str = GENERAL_KNOWLEDGE_INSTRUCTION,
+        json_mode: bool = True,
+        callbacks: list[GlobalSearchLLMCallback] | None = None,
+        max_data_tokens: int = 8000,
+        map_llm_params: dict[str, Any] = DEFAULT_MAP_LLM_PARAMS,
+        reduce_llm_params: dict[str, Any] = DEFAULT_REDUCE_LLM_PARAMS,
+        context_builder_params: dict[str, Any] | None = None,
+        concurrent_coroutines: int = 32,
+    ):
+        super().__init__(
+            llm=llm,
+            context_builder=context_builder,
+            token_encoder=token_encoder,
+            context_builder_params=context_builder_params,
+        )
+        self.map_system_prompt = map_system_prompt
+        self.reduce_system_prompt = reduce_system_prompt
+        self.response_type = response_type
+        self.allow_general_knowledge = allow_general_knowledge
+        self.general_knowledge_inclusion_prompt = general_knowledge_inclusion_prompt
+        self.callbacks = callbacks
+        self.max_data_tokens = max_data_tokens
+
+        # copy the param dicts so the shared module-level defaults are never mutated
+        self.map_llm_params = dict(map_llm_params)
+        self.reduce_llm_params = dict(reduce_llm_params)
+        if json_mode:
+            self.map_llm_params["response_format"] = {"type": "json_object"}
+        else:
+            # remove response_format key if json_mode is False
+            self.map_llm_params.pop("response_format", None)
+
+        self.semaphore = asyncio.Semaphore(concurrent_coroutines)
+
+    async def asearch(
+        self,
+        query: str,
+        conversation_history: ConversationHistory | None = None,
+        **kwargs: Any,
+    ) -> GlobalSearchResult:
+        """
+        Perform a global search.
+
+        Global search mode includes two steps:
+
+        - Step 1: Run parallel LLM calls on communities' short summaries to generate an answer for each batch
+        - Step 2: Combine the answers from step 1 to generate the final answer
+        """
+        # Step 1: Generate answers for each batch of community short summaries
+        start_time = time.time()
+        context_chunks, context_records = self.context_builder.build_context(
+            conversation_history=conversation_history, **self.context_builder_params
+        )
+
+        if self.callbacks:
+            for callback in self.callbacks:
+                callback.on_map_response_start(context_chunks)  # type: ignore
+        map_responses = await asyncio.gather(*[
+            self._map_response_single_batch(
+                context_data=data, query=query, **self.map_llm_params
+            )
+            for data in context_chunks
+        ])
+        if self.callbacks:
+            for callback in self.callbacks:
+                callback.on_map_response_end(map_responses)
+        map_llm_calls = sum(response.llm_calls for response in map_responses)
+        map_prompt_tokens = sum(response.prompt_tokens for response in map_responses)
+
+        # Step 2: Combine the intermediate answers from step 1 to generate the final answer
+        reduce_response = await self._reduce_response(
+            map_responses=map_responses,
+            query=query,
+            **self.reduce_llm_params,
+        )
+
+        return GlobalSearchResult(
+            response=reduce_response.response,
+            context_data=context_records,
+            context_text=context_chunks,
+            map_responses=map_responses,
+            reduce_context_data=reduce_response.context_data,
+            reduce_context_text=reduce_response.context_text,
+            completion_time=time.time() - start_time,
+            llm_calls=map_llm_calls + reduce_response.llm_calls,
+            prompt_tokens=map_prompt_tokens + reduce_response.prompt_tokens,
+        )
+
+    def search(
+        self,
+        query: str,
+        conversation_history: ConversationHistory | None = None,
+        **kwargs: Any,
+    ) -> GlobalSearchResult:
+        """Perform a global search synchronously."""
+        return asyncio.run(self.asearch(query, conversation_history))
+
+    async def _map_response_single_batch(
+        self,
+        context_data: str,
+        query: str,
+        **llm_kwargs,
+    ) -> SearchResult:
+        """Generate an answer for a single chunk of community reports."""
+        start_time = time.time()
+        search_prompt = ""
+        try:
+            search_prompt = self.map_system_prompt.format(context_data=context_data)
+            search_messages = [
+                {"role": "system", "content": search_prompt},
+                {"role": "user", "content": query},
+            ]
+            async with self.semaphore:
+                search_response = await self.llm.agenerate(
+                    messages=search_messages, streaming=False, **llm_kwargs
+                )
+                log.info("Map response: %s", search_response)
+            try:
+                # parse search response json
+                processed_response = self.parse_search_response(search_response)
+            except ValueError:
+                # retry the parse once before giving up on this batch
+                try:
+                    processed_response = self.parse_search_response(search_response)
+                except ValueError:
+                    log.warning(
+                        "Warning: Error parsing search response json - skipping this batch"
+                    )
+                    processed_response = []
+
+            return SearchResult(
+                response=processed_response,
+                context_data=context_data,
+                context_text=context_data,
+                completion_time=time.time() - start_time,
+                llm_calls=1,
+                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+            )
+
+        except Exception:
+            log.exception("Exception in _map_response_single_batch")
+            return SearchResult(
+                response=[{"answer": "", "score": 0}],
+                context_data=context_data,
+                context_text=context_data,
+                completion_time=time.time() - start_time,
+                llm_calls=1,
+                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+            )
+
+    def parse_search_response(self,
search_response: str) -> list[dict[str, Any]]: + """Parse the search response json and return a list of key points. + + Parameters + ---------- + search_response: str + The search response json string + + Returns + ------- + list[dict[str, Any]] + A list of key points, each key point is a dictionary with "answer" and "score" keys + """ + search_response, _j = try_parse_json_object(search_response) + if _j == {}: + return [{"answer": "", "score": 0}] + + parsed_elements = json.loads(search_response).get("points") + if not parsed_elements or not isinstance(parsed_elements, list): + return [{"answer": "", "score": 0}] + + return [ + { + "answer": element["description"], + "score": int(element["score"]), + } + for element in parsed_elements + if "description" in element and "score" in element + ] + + async def _reduce_response( + self, + map_responses: list[SearchResult], + query: str, + **llm_kwargs, + ) -> SearchResult: + """Combine all intermediate responses from single batches into a final answer to the user query.""" + text_data = "" + search_prompt = "" + start_time = time.time() + try: + # collect all key points into a single list to prepare for sorting + key_points = [] + for index, response in enumerate(map_responses): + if not isinstance(response.response, list): + continue + for element in response.response: + if not isinstance(element, dict): + continue + if "answer" not in element or "score" not in element: + continue + key_points.append({ + "analyst": index, + "answer": element["answer"], + "score": element["score"], + }) + + # filter response with score = 0 and rank responses by descending order of score + filtered_key_points = [ + point + for point in key_points + if point["score"] > 0 # type: ignore + ] + + if len(filtered_key_points) == 0 and not self.allow_general_knowledge: + # return no data answer if no key points are found + log.warning( + "Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), returning a canned 'I do not know' answer. You can try enabling `allow_general_knowledge` to encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations." 
+ ) + return SearchResult( + response=NO_DATA_ANSWER, + context_data="", + context_text="", + completion_time=time.time() - start_time, + llm_calls=0, + prompt_tokens=0, + ) + + filtered_key_points = sorted( + filtered_key_points, + key=lambda x: x["score"], # type: ignore + reverse=True, # type: ignore + ) + + data = [] + total_tokens = 0 + for point in filtered_key_points: + formatted_response_data = [] + formatted_response_data.append( + f'----Analyst {point["analyst"] + 1}----' + ) + formatted_response_data.append( + f'Importance Score: {point["score"]}' # type: ignore + ) + formatted_response_data.append(point["answer"]) # type: ignore + formatted_response_text = "\n".join(formatted_response_data) + if ( + total_tokens + + num_tokens(formatted_response_text, self.token_encoder) + > self.max_data_tokens + ): + break + data.append(formatted_response_text) + total_tokens += num_tokens(formatted_response_text, self.token_encoder) + text_data = "\n\n".join(data) + + search_prompt = self.reduce_system_prompt.format( + report_data=text_data, response_type=self.response_type + ) + if self.allow_general_knowledge: + search_prompt += "\n" + self.general_knowledge_inclusion_prompt + search_messages = [ + {"role": "system", "content": search_prompt}, + {"role": "user", "content": query}, + ] + + search_response = await self.llm.agenerate( + search_messages, + streaming=True, + callbacks=self.callbacks, # type: ignore + **llm_kwargs, # type: ignore + ) + return SearchResult( + response=search_response, + context_data=text_data, + context_text=text_data, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + except Exception: + log.exception("Exception in reduce_response") + return SearchResult( + response="", + context_data=text_data, + context_text=text_data, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) diff --git a/func-app/graphrag/query/structured_search/local_search/__init__.py b/func-app/graphrag/query/structured_search/local_search/__init__.py new file mode 100644 index 0000000000..8b8b1e790e --- /dev/null +++ b/func-app/graphrag/query/structured_search/local_search/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The LocalSearch package.""" diff --git a/func-app/graphrag/query/structured_search/local_search/mixed_context.py b/func-app/graphrag/query/structured_search/local_search/mixed_context.py new file mode 100644 index 0000000000..af4c63d55b --- /dev/null +++ b/func-app/graphrag/query/structured_search/local_search/mixed_context.py @@ -0,0 +1,533 @@ +# Copyright (c) 2024 Microsoft Corporation. 
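
The reduce step above packs the ranked analyst points into a fixed token budget before building the final prompt. The same greedy loop in isolation, using num_tokens from text_utils; the points and the budget value are toy inputs:

    from graphrag.query.llm.text_utils import num_tokens

    ranked_points = [
        {"analyst": 0, "score": 90, "answer": "Key finding one."},
        {"analyst": 1, "score": 40, "answer": "Key finding two."},
    ]

    max_data_tokens = 50  # toy budget; GlobalSearch defaults to 8000
    data, total_tokens = [], 0
    for point in ranked_points:
        formatted = "\n".join([
            f'----Analyst {point["analyst"] + 1}----',
            f'Importance Score: {point["score"]}',
            point["answer"],
        ])
        if total_tokens + num_tokens(formatted) > max_data_tokens:
            break  # budget exhausted: drop the remaining, lower-ranked points
        data.append(formatted)
        total_tokens += num_tokens(formatted)

    report_data = "\n\n".join(data)  # becomes {report_data} in the reduce prompt
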
+# Licensed under the MIT License
+
+"""Algorithms to build context data for local search prompt."""
+
+import logging
+from typing import Any
+
+import pandas as pd
+import tiktoken
+
+from common.graph_db_client import GraphDBClient
+from graphrag.config.models.graphdb_config import GraphDBConfig
+from graphrag.model import (
+    CommunityReport,
+    Covariate,
+    Entity,
+    Relationship,
+    TextUnit,
+)
+from graphrag.query.context_builder.community_context import (
+    build_community_context,
+)
+from graphrag.query.context_builder.conversation_history import (
+    ConversationHistory,
+)
+from graphrag.query.context_builder.entity_extraction import (
+    EntityVectorStoreKey,
+    map_query_to_entities,
+)
+from graphrag.query.context_builder.local_context import (
+    build_covariates_context,
+    build_entity_context,
+    build_relationship_context,
+    get_candidate_context,
+)
+from graphrag.query.context_builder.source_context import (
+    build_text_unit_context,
+    count_relationships,
+)
+from graphrag.query.input.retrieval.community_reports import (
+    get_candidate_communities,
+)
+from graphrag.query.input.retrieval.text_units import get_candidate_text_units
+from graphrag.query.llm.base import BaseTextEmbedding
+from graphrag.query.llm.text_utils import num_tokens
+from graphrag.query.structured_search.base import LocalContextBuilder
+from graphrag.vector_stores import BaseVectorStore
+from graphrag.vector_stores.kusto import KustoVectorStore
+
+log = logging.getLogger(__name__)
+
+
+class LocalSearchMixedContext(LocalContextBuilder):
+    """Build data context for local search prompt combining community reports and entity/relationship/covariate tables."""
+
+    def __init__(
+        self,
+        entities: list[Entity],
+        entity_text_embeddings: BaseVectorStore,
+        text_embedder: BaseTextEmbedding,
+        text_units: list[TextUnit] | None = None,
+        community_reports: list[CommunityReport] | None = None,
+        relationships: list[Relationship] | None = None,
+        covariates: dict[str, list[Covariate]] | None = None,
+        token_encoder: tiktoken.Encoding | None = None,
+        embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
+        is_optimized_search: bool = False,
+        use_kusto_community_reports: bool = False,
+        graphdb_config: GraphDBConfig | None = None,
+        context_id: str | None = None,
+    ):
+        if community_reports is None:
+            community_reports = []
+        if relationships is None:
+            relationships = []
+        if covariates is None:
+            covariates = {}
+        if text_units is None:
+            text_units = []
+        self.entities = {entity.id: entity for entity in entities}
+        self.community_reports = {
+            community.id: community for community in community_reports
+        }
+        self.text_units = {unit.id: unit for unit in text_units}
+        self.relationships = {
+            relationship.id: relationship for relationship in relationships
+        }
+        self.covariates = covariates
+        self.entity_text_embeddings = entity_text_embeddings
+        self.text_embedder = text_embedder
+        self.token_encoder = token_encoder
+        self.embedding_vectorstore_key = embedding_vectorstore_key
+        self.is_optimized_search = is_optimized_search
+        self.use_kusto_community_reports = use_kusto_community_reports
+        self.graphdb_config = graphdb_config
+        self.context_id = context_id
+
+    def filter_by_entity_keys(self, entity_keys: list[int] | list[str]):
+        """Filter entity text embeddings by entity keys."""
+        self.entity_text_embeddings.filter_by_id(entity_keys)
+
+    def build_context(
+        self,
+        query: str,
+        conversation_history: ConversationHistory | None = None,
+        include_entity_names: list[str] | None = None,
+        exclude_entity_names: list[str] |
None = None,
+        conversation_history_max_turns: int | None = 5,
+        conversation_history_user_turns_only: bool = True,
+        max_tokens: int = 8000,
+        text_unit_prop: float = 0.5,
+        community_prop: float = 0.25,
+        top_k_mapped_entities: int = 10,
+        top_k_relationships: int = 10,
+        include_community_rank: bool = False,
+        include_entity_rank: bool = False,
+        rank_description: str = "number of relationships",
+        include_relationship_weight: bool = False,
+        relationship_ranking_attribute: str = "rank",
+        return_candidate_context: bool = False,
+        use_community_summary: bool = False,
+        min_community_rank: int = 0,
+        community_context_name: str = "Reports",
+        column_delimiter: str = "|",
+        is_optimized_search: bool = False,
+        **kwargs: dict[str, Any],
+    ) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
+        """
+        Build data context for local search prompt.
+
+        Build a context by combining community reports, entity/relationship/covariate tables, and text units, splitting the token budget according to the community_prop and text_unit_prop ratios.
+        """
+        if include_entity_names is None:
+            include_entity_names = []
+        if exclude_entity_names is None:
+            exclude_entity_names = []
+        if community_prop + text_unit_prop > 1:
+            value_error = (
+                "The sum of community_prop and text_unit_prop should not exceed 1."
+            )
+            raise ValueError(value_error)
+
+        # map user query to entities
+        # if there is conversation history, attach the previous user questions to the current query
+        if conversation_history:
+            pre_user_questions = "\n".join(
+                conversation_history.get_user_turns(conversation_history_max_turns)
+            )
+            query = f"{query}\n{pre_user_questions}"
+
+        selected_entities = map_query_to_entities(
+            query=query,
+            text_embedding_vectorstore=self.entity_text_embeddings,
+            text_embedder=self.text_embedder,
+            all_entities=list(self.entities.values()),
+            embedding_vectorstore_key=self.embedding_vectorstore_key,
+            include_entity_names=include_entity_names,
+            exclude_entity_names=exclude_entity_names,
+            k=top_k_mapped_entities,
+            oversample_scaler=2,
+        )
+
+        print("Selected entity titles: ", [entity.title for entity in selected_entities])
+
+        # build context
+        final_context = list[str]()
+        final_context_data = dict[str, pd.DataFrame]()
+
+        if conversation_history:
+            # build conversation history context
+            (
+                conversation_history_context,
+                conversation_history_context_data,
+            ) = conversation_history.build_context(
+                include_user_turns_only=conversation_history_user_turns_only,
+                max_qa_turns=conversation_history_max_turns,
+                column_delimiter=column_delimiter,
+                max_tokens=max_tokens,
+                recency_bias=False,
+            )
+            if conversation_history_context.strip() != "":
+                final_context.append(conversation_history_context)
+                final_context_data = conversation_history_context_data
+                max_tokens = max_tokens - num_tokens(
+                    conversation_history_context, self.token_encoder
+                )
+
+        if not is_optimized_search:
+            community_tokens = max(int(max_tokens * community_prop), 0)
+            community_context, community_context_data = self._build_community_context(
+                selected_entities=selected_entities,
+                max_tokens=community_tokens,
+                use_community_summary=use_community_summary,
+                column_delimiter=column_delimiter,
+                include_community_rank=include_community_rank,
+                min_community_rank=min_community_rank,
+                return_candidate_context=return_candidate_context,
+                context_name=community_context_name,
+                is_optimized_search=is_optimized_search
+            )
+            if community_context.strip() != "":
+                final_context.append(community_context)
+                final_context_data = {**final_context_data, **community_context_data}
+
+        # 
+        # build local (i.e. entity-relationship-covariate) context
+        local_prop = 1 - community_prop - text_unit_prop
+        local_tokens = max(int(max_tokens * local_prop), 0)
+        local_context, local_context_data = self._build_local_context(
+            selected_entities=selected_entities,
+            max_tokens=local_tokens,
+            include_entity_rank=include_entity_rank,
+            rank_description=rank_description,
+            include_relationship_weight=include_relationship_weight,
+            top_k_relationships=top_k_relationships,
+            relationship_ranking_attribute=relationship_ranking_attribute,
+            return_candidate_context=return_candidate_context,
+            column_delimiter=column_delimiter,
+            is_optimized_search=is_optimized_search,
+        )
+        if local_context.strip() != "":
+            final_context.append(str(local_context))
+            final_context_data = {**final_context_data, **local_context_data}
+
+        if not is_optimized_search:
+            # build text unit context
+            text_unit_tokens = max(int(max_tokens * text_unit_prop), 0)
+            text_unit_context, text_unit_context_data = self._build_text_unit_context(
+                selected_entities=selected_entities,
+                max_tokens=text_unit_tokens,
+                return_candidate_context=return_candidate_context,
+            )
+            if text_unit_context.strip() != "":
+                final_context.append(text_unit_context)
+                final_context_data = {**final_context_data, **text_unit_context_data}
+
+        return ("\n\n".join(final_context), final_context_data)
+
+    def _build_community_context(
+        self,
+        selected_entities: list[Entity],
+        max_tokens: int = 4000,
+        use_community_summary: bool = False,
+        column_delimiter: str = "|",
+        include_community_rank: bool = False,
+        min_community_rank: int = 0,
+        return_candidate_context: bool = False,
+        context_name: str = "Reports",
+        is_optimized_search: bool = False,
+    ) -> tuple[str, dict[str, pd.DataFrame]]:
+        """Add community data to the context window until it hits the max_tokens limit."""
+        if len(selected_entities) == 0 or (len(self.community_reports) == 0 and not self.use_kusto_community_reports):
+            return ("", {context_name.lower(): pd.DataFrame()})
+
+        community_matches = {}
+        for entity in selected_entities:
+            # increase count of the community that this entity belongs to
+            if entity.community_ids:
+                for community_id in entity.community_ids:
+                    community_matches[community_id] = (
+                        community_matches.get(community_id, 0) + 1
+                    )
+
+        selected_communities = []
+        if self.use_kusto_community_reports:
+            selected_communities = self.entity_text_embeddings.get_extracted_reports(
+                community_ids=list(community_matches.keys())
+            )
+        else:
+            selected_communities = [
+                self.community_reports[community_id]
+                for community_id in community_matches
+                if community_id in self.community_reports
+            ]
+
+        # sort communities by number of matched entities and rank
+        for community in selected_communities:
+            if community.attributes is None:
+                community.attributes = {}
+            community.attributes["matches"] = community_matches[community.id]
+        selected_communities.sort(
+            key=lambda x: (x.attributes["matches"], x.rank),  # type: ignore
+            reverse=True,  # type: ignore
+        )
+        for community in selected_communities:
+            del community.attributes["matches"]  # type: ignore
+        context_data = {}
+        context_data["reports"] = selected_communities
+        context_text = ""
+        if not is_optimized_search:
+            context_text, context_data = build_community_context(
+                community_reports=selected_communities,
+                token_encoder=self.token_encoder,
+                use_community_summary=use_community_summary,
+                column_delimiter=column_delimiter,
+                shuffle_data=False,
+                include_community_rank=include_community_rank,
+                
min_community_rank=min_community_rank, + max_tokens=max_tokens, + single_batch=True, + context_name=context_name, + ) + if isinstance(context_text, list) and len(context_text) > 0: + context_text = "\n\n".join(context_text) + + if return_candidate_context: + candidate_context_data = get_candidate_communities( + selected_entities=selected_entities, + community_reports=list(self.community_reports.values()), + use_community_summary=use_community_summary, + include_community_rank=include_community_rank, + ) + context_key = context_name.lower() + if context_key not in context_data: + context_data[context_key] = candidate_context_data + context_data[context_key]["in_context"] = False + else: + if ( + "id" in candidate_context_data.columns + and "id" in context_data[context_key].columns + ): + candidate_context_data["in_context"] = candidate_context_data[ + "id" + ].isin( # cspell:disable-line + context_data[context_key]["id"] + ) + context_data[context_key] = candidate_context_data + else: + context_data[context_key]["in_context"] = True + return (str(context_text), context_data) + + def _build_text_unit_context( + self, + selected_entities: list[Entity], + max_tokens: int = 8000, + return_candidate_context: bool = False, + column_delimiter: str = "|", + context_name: str = "Sources", + ) -> tuple[str, dict[str, pd.DataFrame]]: + """Rank matching text units and add them to the context window until it hits the max_tokens limit.""" + if len(selected_entities) == 0 or len(self.text_units) == 0: + return ("", {context_name.lower(): pd.DataFrame()}) + + selected_text_units = list[TextUnit]() + # for each matching text unit, rank first by the order of the entities that match it, then by the number of matching relationships + # that the text unit has with the matching entities + for index, entity in enumerate(selected_entities): + if entity.text_unit_ids: + for text_id in entity.text_unit_ids: + if ( + text_id not in [unit.id for unit in selected_text_units] + and text_id in self.text_units + ): + selected_unit = self.text_units[text_id] + num_relationships = count_relationships( + selected_unit, entity, self.relationships + ) + if selected_unit.attributes is None: + selected_unit.attributes = {} + selected_unit.attributes["entity_order"] = index + selected_unit.attributes["num_relationships"] = ( + num_relationships + ) + selected_text_units.append(selected_unit) + + # sort selected text units by ascending order of entity order and descending order of number of relationships + selected_text_units.sort( + key=lambda x: ( + x.attributes["entity_order"], # type: ignore + -x.attributes["num_relationships"], # type: ignore + ) + ) + + for unit in selected_text_units: + del unit.attributes["entity_order"] # type: ignore + del unit.attributes["num_relationships"] # type: ignore + + context_text, context_data = build_text_unit_context( + text_units=selected_text_units, + token_encoder=self.token_encoder, + max_tokens=max_tokens, + shuffle_data=False, + context_name=context_name, + column_delimiter=column_delimiter, + ) + + if return_candidate_context: + candidate_context_data = get_candidate_text_units( + selected_entities=selected_entities, + text_units=list(self.text_units.values()), + ) + context_key = context_name.lower() + if context_key not in context_data: + context_data[context_key] = candidate_context_data + context_data[context_key]["in_context"] = False + else: + if ( + "id" in candidate_context_data.columns + and "id" in context_data[context_key].columns + ): + candidate_context_data["in_context"] 
= candidate_context_data[
+                        "id"
+                    ].isin(  # cspell:disable-line
+                        context_data[context_key]["id"]
+                    )
+                    context_data[context_key] = candidate_context_data
+                else:
+                    context_data[context_key]["in_context"] = True
+        return (str(context_text), context_data)
+
+    def _build_local_context(
+        self,
+        selected_entities: list[Entity],
+        max_tokens: int = 8000,
+        include_entity_rank: bool = False,
+        rank_description: str = "relationship count",
+        include_relationship_weight: bool = False,
+        top_k_relationships: int = 10,
+        relationship_ranking_attribute: str = "rank",
+        return_candidate_context: bool = False,
+        column_delimiter: str = "|",
+        is_optimized_search: bool = False,
+    ) -> tuple[str, dict[str, pd.DataFrame]]:
+        """Build data context for local search prompt combining entity/relationship/covariate tables."""
+        # build entity context
+        entity_context, entity_context_data = build_entity_context(
+            selected_entities=selected_entities,
+            token_encoder=self.token_encoder,
+            max_tokens=max_tokens,
+            column_delimiter=column_delimiter,
+            include_entity_rank=include_entity_rank,
+            rank_description=rank_description,
+            context_name="Entities",
+            is_optimized_search=is_optimized_search,
+        )
+        entity_tokens = num_tokens(entity_context, self.token_encoder)
+
+        # build relationship-covariate context
+        added_entities = []
+        final_context = []
+        final_context_data = {}
+
+        # gradually add entities and associated metadata to the context until we reach limit
+        graphdb_client = (
+            GraphDBClient(self.graphdb_config, self.context_id)
+            if (self.graphdb_config and self.graphdb_config.enabled)
+            else None
+        )
+        for entity in selected_entities:
+            current_context = []
+            current_context_data = {}
+            added_entities.append(entity)
+
+            # build relationship context
+            (
+                relationship_context,
+                relationship_context_data,
+            ) = build_relationship_context(
+                selected_entities=added_entities,
+                relationships=list(self.relationships.values()),
+                token_encoder=self.token_encoder,
+                max_tokens=max_tokens,
+                column_delimiter=column_delimiter,
+                top_k_relationships=top_k_relationships,
+                include_relationship_weight=include_relationship_weight,
+                relationship_ranking_attribute=relationship_ranking_attribute,
+                context_name="Relationships",
+                is_optimized_search=is_optimized_search,
+                graphdb_client=graphdb_client,
+            )
+            current_context.append(relationship_context)
+            current_context_data["relationships"] = relationship_context_data
+            total_tokens = entity_tokens + num_tokens(
+                relationship_context, self.token_encoder
+            )
+
+            # build covariate context
+            for covariate in self.covariates:
+                covariate_context, covariate_context_data = build_covariates_context(
+                    selected_entities=added_entities,
+                    covariates=self.covariates[covariate],
+                    token_encoder=self.token_encoder,
+                    max_tokens=max_tokens,
+                    column_delimiter=column_delimiter,
+                    context_name=covariate,
+                    is_optimized_search=is_optimized_search,
+                )
+                total_tokens += num_tokens(covariate_context, self.token_encoder)
+                current_context.append(covariate_context)
+                current_context_data[covariate.lower()] = covariate_context_data
+
+            if total_tokens > max_tokens:
+                log.info("Reached token limit - reverting to previous context state")
+                break
+
+            final_context = current_context
+            final_context_data = current_context_data
+
+        # attach entity context to final context
+        if graphdb_client:
+            graphdb_client._client.close()
+        final_context_text = entity_context + "\n\n" + "\n\n".join(final_context)
+        final_context_data["entities"] = entity_context_data
+
+        if return_candidate_context:
+            # we return all the 
candidate entities/relationships/covariates (not only those that were fitted into the context window) + # and add a tag to indicate which records were included in the context window + candidate_context_data = get_candidate_context( + selected_entities=selected_entities, + entities=list(self.entities.values()), + relationships=list(self.relationships.values()), + covariates=self.covariates, + include_entity_rank=include_entity_rank, + entity_rank_description=rank_description, + include_relationship_weight=include_relationship_weight, + ) + for key in candidate_context_data: + candidate_df = candidate_context_data[key] + if key not in final_context_data: + final_context_data[key] = candidate_df + final_context_data[key]["in_context"] = False + else: + in_context_df = final_context_data[key] + + if "id" in in_context_df.columns and "id" in candidate_df.columns: + candidate_df["in_context"] = candidate_df[ + "id" + ].isin( # cspell:disable-line + in_context_df["id"] + ) + final_context_data[key] = candidate_df + else: + final_context_data[key]["in_context"] = True + + else: + for key in final_context_data: + final_context_data[key]["in_context"] = True + return (final_context_text, final_context_data) diff --git a/func-app/graphrag/query/structured_search/local_search/search.py b/func-app/graphrag/query/structured_search/local_search/search.py new file mode 100644 index 0000000000..597b511222 --- /dev/null +++ b/func-app/graphrag/query/structured_search/local_search/search.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LocalSearch implementation.""" + +import logging +import time +from typing import Any + +import tiktoken + +from graphrag.query.context_builder.builders import LocalContextBuilder +from graphrag.query.context_builder.conversation_history import ( + ConversationHistory, +) +from graphrag.query.llm.base import BaseLLM, BaseLLMCallback +from graphrag.query.llm.text_utils import num_tokens +from graphrag.query.structured_search.base import BaseSearch, SearchResult +from graphrag.query.structured_search.local_search.system_prompt import ( + LOCAL_SEARCH_SYSTEM_PROMPT, +) + +DEFAULT_LLM_PARAMS = { + "max_tokens": 1500, + "temperature": 0.0, +} + +log = logging.getLogger(__name__) + + +class LocalSearch(BaseSearch): + """Search orchestration for local search mode.""" + + def __init__( + self, + llm: BaseLLM, + context_builder: LocalContextBuilder, + token_encoder: tiktoken.Encoding | None = None, + system_prompt: str = LOCAL_SEARCH_SYSTEM_PROMPT, + response_type: str = "multiple paragraphs", + callbacks: list[BaseLLMCallback] | None = None, + llm_params: dict[str, Any] = DEFAULT_LLM_PARAMS, + context_builder_params: dict | None = None, + ): + super().__init__( + llm=llm, + context_builder=context_builder, + token_encoder=token_encoder, + llm_params=llm_params, + context_builder_params=context_builder_params or {}, + ) + self.system_prompt = system_prompt + self.callbacks = callbacks + self.response_type = response_type + + async def asearch( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> SearchResult: + """Build local search context that fits a single context window and generate answer for the user query.""" + start_time = time.time() + search_prompt = "" + + context_text, context_records = self.context_builder.build_context( + query=query, + conversation_history=conversation_history, + **kwargs, + **self.context_builder_params, + ) + log.info("GENERATE ANSWER: %s. 
QUERY: %s", start_time, query) + try: + search_prompt = self.system_prompt.format( + context_data=context_text, response_type=self.response_type + ) + search_messages = [ + {"role": "system", "content": search_prompt}, + {"role": "user", "content": query}, + ] + + response = await self.llm.agenerate( + messages=search_messages, + streaming=True, + callbacks=self.callbacks, + **self.llm_params, + ) + + return SearchResult( + response=response, + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + + except Exception: + log.exception("Exception in _asearch") + return SearchResult( + response="", + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + + def search( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> SearchResult: + """Build local search context that fits a single context window and generate answer for the user question.""" + + start_time = time.time() + search_prompt = "" + context_text, context_records = self.context_builder.build_context( + query=query, + conversation_history=conversation_history, + **kwargs, + **self.context_builder_params, + ) + log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query) + try: + search_prompt = self.system_prompt.format( + context_data=context_text, response_type=self.response_type + ) + search_messages = [ + {"role": "system", "content": search_prompt}, + {"role": "user", "content": query}, + ] + + response = self.llm.generate( + messages=search_messages, + streaming=True, + callbacks=self.callbacks, + **self.llm_params, + ) + + return SearchResult( + response=response, + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + + except Exception: + log.exception("Exception in _map_response_single_batch") + return SearchResult( + response="", + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + + + def optimized_search( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> SearchResult: + """Build local search context data.""" + start_time = time.time() + search_prompt = "" + context_text, context_records = self.context_builder.build_context( + query=query, + conversation_history=conversation_history, + is_optimized_search = self.optimized_search, + **kwargs, + **self.context_builder_params, + ) + log.info("GENERATE ANSWER: %d. 
QUERY: %s", start_time, query) + try: + return SearchResult( + response="", + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + + except Exception: + log.exception("Exception in _map_response_single_batch") + return SearchResult( + response="", + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) diff --git a/func-app/graphrag/query/structured_search/local_search/system_prompt.py b/func-app/graphrag/query/structured_search/local_search/system_prompt.py new file mode 100644 index 0000000000..70b1d12fc3 --- /dev/null +++ b/func-app/graphrag/query/structured_search/local_search/system_prompt.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Local search system prompts.""" + +LOCAL_SEARCH_SYSTEM_PROMPT = """ +---Role--- + +You are a helpful assistant responding to questions about data in the tables provided. + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. + +If you don't know the answer, just say so. Do not make anything up. + +Points supported by data should list their data references as follows: + +"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." + +Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]." + +where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + + +---Data tables--- + +{context_data} + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. + +If you don't know the answer, just say so. Do not make anything up. + +Points supported by data should list their data references as follows: + +"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." + +Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]." + +where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. 
+ + +---Target response length and format--- + +{response_type} + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. +""" diff --git a/func-app/graphrag/vector_stores/__init__.py b/func-app/graphrag/vector_stores/__init__.py new file mode 100644 index 0000000000..d4c11760aa --- /dev/null +++ b/func-app/graphrag/vector_stores/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing vector-storage implementations.""" + +from .azure_ai_search import AzureAISearch +from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult +from .lancedb import LanceDBVectorStore +from .typing import VectorStoreFactory, VectorStoreType + +__all__ = [ + "AzureAISearch", + "BaseVectorStore", + "LanceDBVectorStore", + "VectorStoreDocument", + "VectorStoreFactory", + "VectorStoreSearchResult", + "VectorStoreType", +] diff --git a/func-app/graphrag/vector_stores/azure_ai_search.py b/func-app/graphrag/vector_stores/azure_ai_search.py new file mode 100644 index 0000000000..9a53c9a5b3 --- /dev/null +++ b/func-app/graphrag/vector_stores/azure_ai_search.py @@ -0,0 +1,225 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the Azure AI Search vector store implementation.""" + +import json +from typing import Any + +from azure.core.credentials import AzureKeyCredential +from azure.identity import DefaultAzureCredential +from azure.search.documents import SearchClient +from azure.search.documents.indexes import SearchIndexClient +from azure.search.documents.indexes.models import ( + HnswAlgorithmConfiguration, + HnswParameters, + SearchableField, + SearchField, + SearchFieldDataType, + SearchIndex, + SimpleField, + VectorSearch, + VectorSearchAlgorithmMetric, + VectorSearchProfile, +) +from azure.search.documents.models import VectorizedQuery + +from graphrag.model.community_report import CommunityReport +from graphrag.model.entity import Entity +from graphrag.model.types import TextEmbedder +from graphrag.model import TextUnit + +from .base import ( + DEFAULT_VECTOR_SIZE, + BaseVectorStore, + VectorStoreDocument, + VectorStoreSearchResult, +) + + +class AzureAISearch(BaseVectorStore): + """The Azure AI Search vector storage implementation.""" + + index_client: SearchIndexClient + + def connect(self, **kwargs: Any) -> Any: + """Connect to the AzureAI vector store.""" + url = kwargs.get("url", None) + api_key = kwargs.get("api_key", None) + audience = kwargs.get("audience", None) + self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE) + + self.vector_search_profile_name = kwargs.get( + "vector_search_profile_name", "vectorSearchProfile" + ) + + if url: + audience_arg = {"audience": audience} if audience else {} + self.db_connection = SearchClient( + endpoint=url, + index_name=self.collection_name, + credential=AzureKeyCredential(api_key) + if api_key + else DefaultAzureCredential(), + **audience_arg, + ) + self.index_client = SearchIndexClient( + endpoint=url, + credential=AzureKeyCredential(api_key) + if api_key + else DefaultAzureCredential(), + **audience_arg, + ) + else: + not_supported_error = "AAISearchDBClient is not supported on local host." 
+            raise ValueError(not_supported_error)
+
+    def load_documents(
+        self, documents: list[VectorStoreDocument], overwrite: bool = True
+    ) -> None:
+        """Load documents into the Azure AI Search index."""
+        if overwrite:
+            if self.collection_name in self.index_client.list_index_names():
+                self.index_client.delete_index(self.collection_name)
+
+            # Configure the vector search profile
+            vector_search = VectorSearch(
+                algorithms=[
+                    HnswAlgorithmConfiguration(
+                        name="HnswAlg",
+                        parameters=HnswParameters(
+                            metric=VectorSearchAlgorithmMetric.COSINE
+                        ),
+                    )
+                ],
+                profiles=[
+                    VectorSearchProfile(
+                        name=self.vector_search_profile_name,
+                        algorithm_configuration_name="HnswAlg",
+                    )
+                ],
+            )
+
+            index = SearchIndex(
+                name=self.collection_name,
+                fields=[
+                    SimpleField(
+                        name="id",
+                        type=SearchFieldDataType.String,
+                        key=True,
+                    ),
+                    SearchField(
+                        name="vector",
+                        type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
+                        searchable=True,
+                        vector_search_dimensions=self.vector_size,
+                        vector_search_profile_name=self.vector_search_profile_name,
+                    ),
+                    SearchableField(name="text", type=SearchFieldDataType.String),
+                    SimpleField(
+                        name="attributes",
+                        type=SearchFieldDataType.String,
+                    ),
+                ],
+                vector_search=vector_search,
+            )
+
+            self.index_client.create_or_update_index(
+                index,
+            )
+
+        batch = [
+            {
+                "id": doc.id,
+                "vector": doc.vector,
+                "text": doc.text,
+                "attributes": json.dumps(doc.attributes),
+            }
+            for doc in documents
+            if doc.vector is not None
+        ]
+
+        if batch:
+            self.db_connection.upload_documents(batch)
+
+    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
+        """Build a query filter to filter documents by a list of ids."""
+        if include_ids is None or len(include_ids) == 0:
+            self.query_filter = None
+            # Returning to keep consistency with other methods, but not needed
+            return self.query_filter
+
+        # More info about odata filtering here: https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
+        # search.in is faster than joined and/or conditions
+        id_filter = ",".join([f"{id!s}" for id in include_ids])
+        self.query_filter = f"search.in(id, '{id_filter}', ',')"
+
+        # Returning to keep consistency with other methods, but not needed
+        # TODO: Refactor on a future PR
+        return self.query_filter
+
+    def similarity_search_by_vector(
+        self, query_embedding: list[float], k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform a vector-based similarity search."""
+        vectorized_query = VectorizedQuery(
+            vector=query_embedding, k_nearest_neighbors=k, fields="vector"
+        )
+
+        response = self.db_connection.search(
+            vector_queries=[vectorized_query],
+        )
+
+        return [
+            VectorStoreSearchResult(
+                document=VectorStoreDocument(
+                    id=doc.get("id", ""),
+                    text=doc.get("text", ""),
+                    vector=doc.get("vector", []),
+                    attributes=(json.loads(doc.get("attributes", "{}"))),
+                ),
+                # Cosine similarity between 0.333 and 1.000
+                # https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
+                score=doc["@search.score"],
+            )
+            for doc in response
+        ]
+
+    def similarity_search_by_text(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform a text-based similarity search."""
+        query_embedding = text_embedder(text)
+        if query_embedding:
+            return self.similarity_search_by_vector(
+                query_embedding=query_embedding, k=k
+            )
+        return []
+
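+    # Illustrative flow (comment only; the endpoint, key, and `embedder` are
+    # assumptions):
+    #
+    #     store = AzureAISearch(
+    #         collection_name="entity_description_embeddings",
+    #         vector_name="vector",
+    #         reports_name="reports",
+    #         text_units_name="text_units",
+    #     )
+    #     store.connect(url="https://<service>.search.windows.net", api_key="<key>")
+    #     hits = store.similarity_search_by_text("unity march", text_embedder=embedder, k=5)
+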
+    def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None:
+        """Load entities into the vector-store (not supported for Azure AI Search)."""
+        raise NotImplementedError("Loading entities is not supported for Azure AI Search")
+
+    def get_extracted_entities(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[Entity]:
+        """Extract matching entities for a query (not supported for Azure AI Search)."""
+        raise NotImplementedError("Extracting entities is not supported for Azure AI Search")
+
+    def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None:
+        """Load community reports (not supported for Azure AI Search)."""
+        raise NotImplementedError("Loading reports is not supported for Azure AI Search")
+
+    def load_text_units(self, units: list[TextUnit], overwrite: bool = True) -> None:
+        """Load text units (not supported for Azure AI Search)."""
+        raise NotImplementedError("Loading text units is not supported for Azure AI Search")
+
+    def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]:
+        """Get reports for a list of community ids (not supported for Azure AI Search)."""
+        raise NotImplementedError("Extracting reports is not supported for Azure AI Search")
+
+    def setup_entities(self) -> None:
+        """Set up the entities index (not supported for Azure AI Search)."""
+        raise NotImplementedError("Setting up entities is not supported for Azure AI Search")
+
+    def setup_reports(self) -> None:
+        """Set up the reports index (not supported for Azure AI Search)."""
+        raise NotImplementedError("Setting up reports is not supported for Azure AI Search")
+
+    def setup_text_units(self) -> None:
+        """Set up the text-units index (not supported for Azure AI Search)."""
+        raise NotImplementedError("Setting up text units is not supported for Azure AI Search")
+
+    def unload_entities(self) -> None:
+        """Remove context data (not supported for Azure AI Search)."""
+        raise NotImplementedError("Unloading entities is not supported for Azure AI Search")
\ No newline at end of file
diff --git a/func-app/graphrag/vector_stores/base.py b/func-app/graphrag/vector_stores/base.py
new file mode 100644
index 0000000000..f143bc77f5
--- /dev/null
+++ b/func-app/graphrag/vector_stores/base.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Base classes for vector stores."""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any
+
+from graphrag.model import TextUnit
+from graphrag.model.community_report import CommunityReport
+from graphrag.model.entity import Entity
+from graphrag.model.types import TextEmbedder
+
+DEFAULT_VECTOR_SIZE: int = 1536
+
+
+@dataclass
+class VectorStoreDocument:
+    """A document that is stored in vector storage."""
+
+    id: str | int
+    """unique id for the document"""
+
+    text: str | None
+    vector: list[float] | None
+
+    attributes: dict[str, Any] = field(default_factory=dict)
+    """store any additional metadata, e.g. title, date ranges, etc"""
+
+
+@dataclass
+class VectorStoreSearchResult:
+    """A vector storage search result."""
+
+    document: VectorStoreDocument
+    """Document that was found."""
+
+    score: float
+    """Similarity score between 0 and 1. 
Higher is more similar."""
+
+
+class BaseVectorStore(ABC):
+    """The base class for vector storage data-access classes."""
+
+    def __init__(
+        self,
+        collection_name: str,
+        vector_name: str,
+        reports_name: str,
+        text_units_name: str,
+        db_connection: Any | None = None,
+        document_collection: Any | None = None,
+        query_filter: Any | None = None,
+        **kwargs: Any,
+    ):
+        self.collection_name = collection_name
+        self.vector_name = vector_name
+        self.reports_name = reports_name
+        self.text_units_name = text_units_name
+        self.db_connection = db_connection
+        self.document_collection = document_collection
+        self.query_filter = query_filter
+        self.kwargs = kwargs
+
+    @abstractmethod
+    def connect(self, **kwargs: Any) -> None:
+        """Connect to vector storage."""
+
+    @abstractmethod
+    def load_documents(
+        self, documents: list[VectorStoreDocument], overwrite: bool = True
+    ) -> None:
+        """Load documents into the vector-store."""
+
+    @abstractmethod
+    def similarity_search_by_vector(
+        self, query_embedding: list[float], k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform ANN search by vector."""
+
+    @abstractmethod
+    def similarity_search_by_text(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform ANN search by text."""
+
+    @abstractmethod
+    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
+        """Build a query filter to filter documents by id."""
+
+    @abstractmethod
+    def get_extracted_entities(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[Entity]:
+        """From a query, build a subtable containing only the matching entities."""
+
+    @abstractmethod
+    def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None:
+        """Load entities into the vector-store."""
+
+    @abstractmethod
+    def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None:
+        """Load community reports into the vector-store."""
+
+    @abstractmethod
+    def get_extracted_reports(
+        self, community_ids: list[int], **kwargs: Any
+    ) -> list[CommunityReport]:
+        """Get reports for a given list of community ids."""
+
+    @abstractmethod
+    def setup_entities(self) -> None:
+        """Set up the entities table in the vector-store."""
+
+    @abstractmethod
+    def setup_reports(self) -> None:
+        """Set up the reports table in the vector-store."""
+
+    @abstractmethod
+    def setup_text_units(self) -> None:
+        """Set up the text-units table in the vector-store."""
+
+    @abstractmethod
+    def load_text_units(self, units: list[TextUnit], overwrite: bool = True) -> None:
+        """Load text units into the vector-store."""
+
+    @abstractmethod
+    def unload_entities(self) -> None:
+        """Remove context from the databases."""
\ No newline at end of file
diff --git a/func-app/graphrag/vector_stores/kusto.py b/func-app/graphrag/vector_stores/kusto.py
new file mode 100644
index 0000000000..273de5ea96
--- /dev/null
+++ b/func-app/graphrag/vector_stores/kusto.py
@@ -0,0 +1,308 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Azure Kusto vector storage implementation package."""
+
+import json
+import logging
+import os
+from typing import Any, List
+
+import pandas as pd
+from azure.kusto.data import KustoClient, KustoConnectionStringBuilder
+from azure.kusto.data.helpers import dataframe_from_result_table
+
+from graphrag.model import TextUnit
+from graphrag.model.community_report import CommunityReport
+from graphrag.model.entity import Entity
+from graphrag.model.types import TextEmbedder
+
+from .base import (
+    BaseVectorStore,
+    VectorStoreDocument,
+    VectorStoreSearchResult,
+)
+
+log = logging.getLogger(__name__)
+
+
+class KustoVectorStore(BaseVectorStore):
+    """The Azure Kusto vector storage implementation."""
+
+    def connect(self, **kwargs: Any) -> Any:
+        """
+        Connect to the vector storage.
+
+        Args:
+            **kwargs: Arbitrary keyword arguments containing connection parameters.
+                - cluster (str): The Kusto cluster URL.
+                - database (str): The Kusto database name.
+                - client_id (str): The client ID for AAD authentication.
+                - client_secret (str): The client secret for AAD authentication.
+                - authority_id (str): The authority ID (tenant ID) for AAD authentication.
+
+        Returns:
+            Any: The Kusto client instance.
+        """
+        cluster = kwargs.get("cluster")
+        database = kwargs.get("database")
+        client_id = kwargs.get("client_id")
+        client_secret = kwargs.get("client_secret")
+        authority_id = kwargs.get("authority_id")
+        env = os.environ.get("ENVIRONMENT")
+        if env == "AZURE":
+            kcsb = KustoConnectionStringBuilder.with_aad_managed_service_identity_authentication(
+                str(cluster), client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"
+            )
+        elif env == "DEVELOPMENT":
+            kcsb = KustoConnectionStringBuilder.with_aad_device_authentication(str(cluster))
+        else:
+            kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication(
+                str(cluster), str(client_id), str(client_secret), str(authority_id)
+            )
+        self.client = KustoClient(kcsb)
+        self.database = database
+
+    def load_documents(
+        self, documents: List[VectorStoreDocument], overwrite: bool = True
+    ) -> None:
+        """
+        Load documents into vector storage.
+
+        Args:
+            documents (List[VectorStoreDocument]): List of documents to be loaded.
+            overwrite (bool): Whether to overwrite the existing table. Defaults to True.
+        """
+        data = [
+            {
+                "id": document.id,
+                "text": document.text,
+                "vector": document.vector,
+                "attributes": json.dumps(document.attributes),
+            }
+            for document in documents
+            if document.vector is not None
+        ]
+
+        if len(data) == 0:
+            return
+
+        # Convert data to DataFrame
+        df = pd.DataFrame(data)
+
+        # Create or replace table
+        if overwrite:
+            command = f".drop table {self.collection_name} ifexists"
+            self.client.execute(self.database, command)
+            command = f".create table {self.collection_name} (id: string, text: string, vector: dynamic, attributes: string)"
+            self.client.execute(self.database, command)
+
+        # Ingest data
+        ingestion_command = f".ingest inline into table {self.collection_name} <| {df.to_csv(index=False, header=False)}"
+        self.client.execute(self.database, ingestion_command)
+
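+    # Illustrative flow (comment only; the cluster URL and database name are
+    # assumptions):
+    #
+    #     store = KustoVectorStore(
+    #         collection_name="entities",
+    #         vector_name="description_embedding",
+    #         reports_name="reports",
+    #         text_units_name="text_units",
+    #     )
+    #     store.connect(cluster="https://<cluster>.kusto.windows.net", database="graphrag")
+    #     store.load_documents(docs)  # docs: list[VectorStoreDocument]
+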
+    def filter_by_id(self, include_ids: List[str] | List[int]) -> Any:
+        """
+        Build a query filter to filter documents by id.
+
+        Args:
+            include_ids (List[str] | List[int]): List of document IDs to include in the filter.
+
+        Returns:
+            Any: The query filter string.
+        """
+        if len(include_ids) == 0:
+            self.query_filter = None
+        else:
+            if isinstance(include_ids[0], str):
+                id_filter = ", ".join([f"'{id}'" for id in include_ids])
+                self.query_filter = f"id in ({id_filter})"
+            else:
+                self.query_filter = (
+                    f"id in ({', '.join([str(id) for id in include_ids])})"
+                )
+        return self.query_filter
+
+    def similarity_search_by_vector(
+        self, query_embedding: List[float], k: int = 10, **kwargs: Any
+    ) -> List[VectorStoreSearchResult]:
+        """
+        Perform a vector-based similarity search. A search to find the k nearest neighbors of the given query vector.
+
+        Args:
+            query_embedding (List[float]): The query embedding vector.
+            k (int): The number of top results to return. Defaults to 10.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            List[VectorStoreSearchResult]: List of search results.
+        """
+        query = f"""
+        let query_vector = dynamic({query_embedding});
+        {self.collection_name}
+        | extend similarity = series_cosine_similarity(query_vector, {self.vector_name})
+        | top {k} by similarity desc
+        """
+        response = self.client.execute(self.database, query)
+        df = dataframe_from_result_table(response.primary_results[0])
+        log.debug("Similarities of the search results: %s", df["similarity"].tolist())
+
+        # Temporary to support the original entity_description_embedding
+        return [
+            VectorStoreSearchResult(
+                document=VectorStoreDocument(
+                    id=row["id"],
+                    text=row["text"],
+                    vector=row[self.vector_name],
+                    attributes=row["attributes"],
+                ),
+                score=1 + float(row["similarity"]),  # 1 + similarity maps cosine similarity to a score between 0 and 2
+            )
+            for _, row in df.iterrows()
+        ]
+
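+    # The generated KQL takes this shape (illustrative; the table and column
+    # names depend on how the store was configured):
+    #
+    #     let query_vector = dynamic([0.12, -0.08, ...]);
+    #     entities
+    #     | extend similarity = series_cosine_similarity(query_vector, description_embedding)
+    #     | top 10 by similarity desc
+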
+    def similarity_search_by_text(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """
+        Perform a similarity search using a given input text.
+
+        Args:
+            text (str): The input text to search for.
+            text_embedder (TextEmbedder): The text embedder to convert text to vector.
+            k (int): The number of top results to return. Defaults to 10.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            List[VectorStoreSearchResult]: List of search results.
+        """
+        query_embedding = text_embedder(text)
+        if query_embedding:
+            return self.similarity_search_by_vector(query_embedding, k)
+        return []
+
+    def get_extracted_entities(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[Entity]:
+        """Embed the query text and return the k most similar rows as Entity objects."""
+        query_embedding = text_embedder(text)
+        query = f"""
+        let query_vector = dynamic({query_embedding});
+        {self.collection_name}
+        | extend similarity = series_cosine_similarity(query_vector, {self.vector_name})
+        | top {k} by similarity desc
+        """
+        response = self.client.execute(self.database, query)
+        df = dataframe_from_result_table(response.primary_results[0])
+
+        return [
+            Entity(
+                id=row["id"],
+                title=row["title"],
+                type=row["type"],
+                description=row["description"],
+                graph_embedding=row["graph_embedding"],
+                text_unit_ids=row["text_unit_ids"],
+                description_embedding=row["description_embedding"],
+                short_id="",
+                community_ids=row["community_ids"],
+                document_ids=row["document_ids"],
+                rank=row["rank"],
+                attributes=row["attributes"],
+            )
+            for _, row in df.iterrows()
+        ]
+
+    def unload_entities(self) -> None:
+        """Drop the entity, text-unit, and report tables for this context."""
+        self.client.execute(self.database, f".drop table {self.collection_name} ifexists")
+        self.client.execute(self.database, f".drop table {self.text_units_name} ifexists")
+        self.client.execute(self.database, f".drop table {self.reports_name} ifexists")
+
+    def setup_entities(self) -> None:
+        """Create the entities table and set Vector16 encoding on the embedding columns."""
+        command = f".drop table {self.collection_name} ifexists"
+        self.client.execute(self.database, command)
+        command = f".create table {self.collection_name} (id: string, short_id: real, title: string, type: string, description: string, description_embedding: dynamic, name_embedding: dynamic, graph_embedding: dynamic, community_ids: dynamic, text_unit_ids: dynamic, document_ids: dynamic, rank: real, attributes: dynamic)"
+        self.client.execute(self.database, command)
+        command = f".alter column {self.collection_name}.graph_embedding policy encoding type = 'Vector16'"
+        self.client.execute(self.database, command)
+        command = f".alter column {self.collection_name}.description_embedding policy encoding type = 'Vector16'"
+        self.client.execute(self.database, command)
+
+    def load_entities(self, entities: list[Entity], overwrite: bool = False) -> None:
+        """Load entities into the Kusto table, optionally recreating it first."""
+        # Convert data to DataFrame
+        df = pd.DataFrame(entities)
+
+        # Create or replace table
+        if overwrite:
+            self.setup_entities()
+
+        # Ingest data
+        ingestion_command = f".ingest inline into table {self.collection_name} <| {df.to_csv(index=False, header=False)}"
+        self.client.execute(self.database, ingestion_command)
+
+    def setup_reports(self) -> None:
+        """Create the reports table and set Vector16 encoding on the embedding columns."""
+        command = f".drop table {self.reports_name} ifexists"
+        self.client.execute(self.database, command)
+        command = f".create table {self.reports_name} (id: string, short_id: string, title: string, community_id: string, summary: string, full_content: string, rank: real, summary_embedding: dynamic, full_content_embedding: dynamic, attributes: dynamic)"
+        self.client.execute(self.database, command)
+        command = f".alter column {self.reports_name}.summary_embedding policy encoding type = 'Vector16'"
+        self.client.execute(self.database, command)
+        command = f".alter column {self.reports_name}.full_content_embedding policy encoding type = 'Vector16'"
+        self.client.execute(self.database, command)
+
+    def load_reports(self, reports: list[CommunityReport], overwrite: bool = False) -> None:
+        """Load community reports into the Kusto table, optionally recreating it first."""
+        # Convert data to DataFrame
+        df = pd.DataFrame(reports)
+
+        # Create or replace table
+        if overwrite:
+            self.setup_reports()
+
+        # Ingest data
+        
ingestion_command = f".ingest inline into table {self.reports_name} <| {df.to_csv(index=False, header=False)}" + self.client.execute(self.database, ingestion_command) + + def setup_text_units(self) -> None: + command = f".drop table {self.text_units_name} ifexists" + self.client.execute(self.database, command) + command = f".create table {self.text_units_name} (id: string, text: string, n_tokens: string, document_ids: string, entity_ids: string, relationship_ids: string)" + self.client.execute(self.database, command) + + + def load_text_units(self, units: list[TextUnit], overwrite: bool = False) -> None: + df = pd.DataFrame(units) + if overwrite: + self.setup_text_units() + + ingestion_command = f".ingest inline into table {self.text_units_name} <| {df.to_csv(index=False, header=False)}" + self.client.execute(self.database, ingestion_command) + + def get_extracted_reports( + self, community_ids: list[int], **kwargs: Any + ) -> list[CommunityReport]: + community_ids = ", ".join([str(id) for id in community_ids]) + query = f""" + {self.reports_name} + | where community_id in ({community_ids}) + """ + response = self.client.execute(self.database, query) + df = dataframe_from_result_table(response.primary_results[0]) + + return [ + CommunityReport( + id=row["id"], + short_id=row["short_id"], + title=row["title"], + community_id=row["community_id"], + summary=row["summary"], + full_content=row["full_content"], + rank=row["rank"], + summary_embedding=row["summary_embedding"], + full_content_embedding=row["full_content_embedding"], + attributes=row["attributes"], + ) for _, row in df.iterrows() + ] diff --git a/func-app/graphrag/vector_stores/lancedb.py b/func-app/graphrag/vector_stores/lancedb.py new file mode 100644 index 0000000000..fb6447d407 --- /dev/null +++ b/func-app/graphrag/vector_stores/lancedb.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""The LanceDB vector storage implementation package.""" + +import lancedb as lancedb # noqa: I001 (Ruff was breaking on this file imports, even tho they were sorted and passed local tests) +from graphrag.model.community_report import CommunityReport +from graphrag.model.entity import Entity +from graphrag.model.types import TextEmbedder +from graphrag.model import TextUnit + +import json +from typing import Any + +import pyarrow as pa + +from .base import ( + BaseVectorStore, + VectorStoreDocument, + VectorStoreSearchResult, +) + + +class LanceDBVectorStore(BaseVectorStore): + """The LanceDB vector storage implementation.""" + + def connect(self, **kwargs: Any) -> Any: + """Connect to the vector storage.""" + db_uri = kwargs.get("db_uri", "./lancedb") + self.db_connection = lancedb.connect(db_uri) # type: ignore + + def load_documents( + self, documents: list[VectorStoreDocument], overwrite: bool = True + ) -> None: + """Load documents into vector storage.""" + data = [ + { + "id": document.id, + "text": document.text, + "vector": document.vector, + "attributes": json.dumps(document.attributes), + } + for document in documents + if document.vector is not None + ] + + if len(data) == 0: + data = None + + schema = pa.schema([ + pa.field("id", pa.string()), + pa.field("text", pa.string()), + pa.field("vector", pa.list_(pa.float64())), + pa.field("attributes", pa.string()), + ]) + if overwrite: + if data: + self.document_collection = self.db_connection.create_table( + self.collection_name, data=data, mode="overwrite" + ) + else: + self.document_collection = self.db_connection.create_table( + self.collection_name, schema=schema, mode="overwrite" + ) + else: + # add data to existing table + self.document_collection = self.db_connection.open_table( + self.collection_name + ) + if data: + self.document_collection.add(data) + + def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: + """Build a query filter to filter documents by id.""" + if len(include_ids) == 0: + self.query_filter = None + else: + if isinstance(include_ids[0], str): + id_filter = ", ".join([f"'{id}'" for id in include_ids]) + self.query_filter = f"id in ({id_filter})" + else: + self.query_filter = ( + f"id in ({', '.join([str(id) for id in include_ids])})" + ) + return self.query_filter + + def similarity_search_by_vector( + self, query_embedding: list[float], k: int = 10, **kwargs: Any + ) -> list[VectorStoreSearchResult]: + """Perform a vector-based similarity search.""" + if self.query_filter: + docs = ( + self.document_collection.search(query=query_embedding) + .where(self.query_filter, prefilter=True) + .limit(k) + .to_list() + ) + else: + docs = ( + self.document_collection.search(query=query_embedding) + .limit(k) + .to_list() + ) + return [ + VectorStoreSearchResult( + document=VectorStoreDocument( + id=doc["id"], + text=doc["text"], + vector=doc["vector"], + attributes=json.loads(doc["attributes"]), + ), + score=1 - abs(float(doc["_distance"])), + ) + for doc in docs + ] + + def similarity_search_by_text( + self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + ) -> list[VectorStoreSearchResult]: + """Perform a similarity search using a given input text.""" + query_embedding = text_embedder(text) + if query_embedding: + return self.similarity_search_by_vector(query_embedding, k) + return [] + + def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: + raise NotImplementedError("Loading entities is not supported for LanceDB") + 
+
+    def get_extracted_entities(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[Entity]:
+        """Extract matching entities for a query (not supported for LanceDB)."""
+        raise NotImplementedError("Extracting entities is not supported for LanceDB")
+
+    def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None:
+        """Load community reports (not supported for LanceDB)."""
+        raise NotImplementedError("Loading reports is not supported for LanceDB")
+
+    def load_text_units(self, units: list[TextUnit], overwrite: bool = True) -> None:
+        """Load text units (not supported for LanceDB)."""
+        raise NotImplementedError("Loading text units is not supported for LanceDB")
+
+    def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]:
+        """Get reports for a list of community ids (not supported for LanceDB)."""
+        raise NotImplementedError("Extracting community reports is not supported for LanceDB")
+
+    def setup_entities(self) -> None:
+        """Set up the entities table (not supported for LanceDB)."""
+        raise NotImplementedError("Setting up entities is not supported for LanceDB")
+
+    def setup_reports(self) -> None:
+        """Set up the reports table (not supported for LanceDB)."""
+        raise NotImplementedError("Setting up community reports is not supported for LanceDB")
+
+    def setup_text_units(self) -> None:
+        """Set up the text-units table (not supported for LanceDB)."""
+        raise NotImplementedError("Setting up text units is not supported for LanceDB")
+
+    def unload_entities(self) -> None:
+        """Remove context tables (not supported for LanceDB)."""
+        raise NotImplementedError("Unloading entities is not supported for LanceDB")
diff --git a/func-app/graphrag/vector_stores/typing.py b/func-app/graphrag/vector_stores/typing.py
new file mode 100644
index 0000000000..459d5b5f56
--- /dev/null
+++ b/func-app/graphrag/vector_stores/typing.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A package containing the supported vector store types."""
+
+from enum import Enum
+from typing import ClassVar
+
+from .azure_ai_search import AzureAISearch
+from .kusto import KustoVectorStore
+from .lancedb import LanceDBVectorStore
+
+
+class VectorStoreType(str, Enum):
+    """The supported vector store types."""
+
+    LanceDB = "lancedb"
+    AzureAISearch = "azure_ai_search"
+    Kusto = "kusto"
+
+
+class VectorStoreFactory:
+    """A factory class for creating vector stores."""
+
+    vector_store_types: ClassVar[dict[str, type]] = {}
+
+    @classmethod
+    def register(cls, vector_store_type: str, vector_store: type):
+        """Register a vector store type."""
+        cls.vector_store_types[vector_store_type] = vector_store
+
+    @classmethod
+    def get_vector_store(
+        cls, vector_store_type: VectorStoreType | str, kwargs: dict
+    ) -> LanceDBVectorStore | AzureAISearch | KustoVectorStore:
+        """Get the vector store type from a string."""
+        match vector_store_type:
+            case VectorStoreType.LanceDB:
+                return LanceDBVectorStore(**kwargs)
+            case VectorStoreType.AzureAISearch:
+                return AzureAISearch(**kwargs)
+            case VectorStoreType.Kusto:
+                return KustoVectorStore(**kwargs)
+            case _:
+                if vector_store_type in cls.vector_store_types:
+                    return cls.vector_store_types[vector_store_type](**kwargs)
+                msg = f"Unknown vector store type: {vector_store_type}"
+                raise ValueError(msg)
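+
+# Illustrative factory usage (comment only; the kwargs shown are assumptions):
+#
+#     store = VectorStoreFactory.get_vector_store(
+#         VectorStoreType.Kusto,
+#         kwargs={
+#             "collection_name": "entities",
+#             "vector_name": "description_embedding",
+#             "reports_name": "reports",
+#             "text_units_name": "text_units",
+#         },
+#     )
+#     store.connect(cluster="https://<cluster>.kusto.windows.net", database="graphrag")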
diff --git a/func-app/host.json b/func-app/host.json
new file mode 100644
index 0000000000..9df913614d
--- /dev/null
+++ b/func-app/host.json
@@ -0,0 +1,15 @@
+{
+    "version": "2.0",
+    "logging": {
+        "applicationInsights": {
+            "samplingSettings": {
+                "isEnabled": true,
+                "excludedTypes": "Request"
+            }
+        }
+    },
+    "extensionBundle": {
+        "id": "Microsoft.Azure.Functions.ExtensionBundle",
+        "version": "[4.*, 5.0.0)"
+    }
+}
\ No newline at end of file
diff --git a/func-app/prompts/claim_extraction.txt b/func-app/prompts/claim_extraction.txt
new file mode 100644
index 0000000000..0b795c3465
--- /dev/null
+++ b/func-app/prompts/claim_extraction.txt
@@ -0,0 +1,52 @@
+
+-Target activity-
+You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.
+
+-Goal-
+Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.
+
+-Steps-
+1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
+2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
+For each claim, extract the following information:
+- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
+- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
+- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
+- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
+- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
+- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
+- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.
+
+Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
+
+3. Return output in English as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
+
+4. When finished, output {completion_delimiter}
+
+-Examples-
+Example 1:
+Entity specification: organization
+Claim description: red flags associated with an entity
+Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
+Output:
+
+(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
+{completion_delimiter} + +Example 2: +Entity specification: Company A, Person C +Claim description: red flags associated with an entity +Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015. +Output: + +(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.) +{record_delimiter} +(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015) +{completion_delimiter} + +-Real Data- +Use the following input for your answer. +Entity specification: {entity_specs} +Claim description: {claim_description} +Text: {input_text} +Output: \ No newline at end of file diff --git a/func-app/prompts/community_report.txt b/func-app/prompts/community_report.txt new file mode 100644 index 0000000000..d71440ab2f --- /dev/null +++ b/func-app/prompts/community_report.txt @@ -0,0 +1,146 @@ + +You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network. + +# Goal +Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims. + +# Report Structure + +The report should include the following sections: + +- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. +- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. +- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community. +- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating. +- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. 
+
+Return output as a well-formed JSON-formatted string with the following format:
+    {{
+        "title": <report_title>,
+        "summary": <executive_summary>,
+        "rating": <impact_severity_rating>,
+        "rating_explanation": <rating_explanation>,
+        "findings": [
+            {{
+                "summary":<insight_1_summary>,
+                "explanation": <insight_1_explanation>
+            }},
+            {{
+                "summary":<insight_2_summary>,
+                "explanation": <insight_2_explanation>
+            }}
+        ]
+    }}
+
+# Grounding Rules
+
+Points supported by data should list their data references as follows:
+
+"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
+
+Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
+
+where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+Do not include information where the supporting evidence for it is not provided.
+
+
+# Example Input
+-----------
+Text:
+
+Entities
+
+id,entity,description
+5,VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March
+6,HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza
+
+Relationships
+
+id,source,target,description
+37,VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March
+38,VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza
+39,VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza
+40,VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza
+41,VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march
+43,HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March
+
+Output:
+{{
+    "title": "Verdant Oasis Plaza and Unity March",
+    "summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.",
+    "rating": 5.0,
+    "rating_explanation": "The impact severity rating is moderate due to the potential for unrest or conflict during the Unity March.",
+    "findings": [
+        {{
+            "summary": "Verdant Oasis Plaza as the central location",
+            "explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes. [Data: Entities (5), Relationships (37, 38, 39, 40, 41,+more)]"
+        }},
+        {{
+            "summary": "Harmony Assembly's role in the community",
+            "explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community.
[Data: Entities (6), Relationships (38, 43)]"
+        }},
+        {{
+            "summary": "Unity March as a significant event",
+            "explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community. [Data: Relationships (39)]"
+        }},
+        {{
+            "summary": "Role of Tribune Spotlight",
+            "explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved. [Data: Relationships (40)]"
+        }}
+    ]
+}}
+
+
+# Real Data
+
+Use the following text for your answer. Do not make anything up in your answer.
+
+Text:
+{input_text}
+
+The report should include the following sections:
+
+- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
+- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
+- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
+- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
+- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
+
+Return output as a well-formed JSON-formatted string with the following format:
+    {{
+        "title": <report_title>,
+        "summary": <executive_summary>,
+        "rating": <impact_severity_rating>,
+        "rating_explanation": <rating_explanation>,
+        "findings": [
+            {{
+                "summary":<insight_1_summary>,
+                "explanation": <insight_1_explanation>
+            }},
+            {{
+                "summary":<insight_2_summary>,
+                "explanation": <insight_2_explanation>
+            }}
+        ]
+    }}
+
+# Grounding Rules
+
+Points supported by data should list their data references as follows:
+
+"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
+
+Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
+
+where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+Do not include information where the supporting evidence for it is not provided.
+
+Output:
\ No newline at end of file
diff --git a/func-app/prompts/entity_extraction.txt b/func-app/prompts/entity_extraction.txt
new file mode 100644
index 0000000000..d47747f7cf
--- /dev/null
+++ b/func-app/prompts/entity_extraction.txt
@@ -0,0 +1,99 @@
+
+-Goal-
+Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
+
+-Steps-
+1. Identify all entities.
For each identified entity, extract the following information:
+- entity_name: Name of the entity, capitalized
+- entity_type: One of the following types: [{entity_types}]
+- entity_description: Comprehensive description of the entity's attributes and activities
+Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
+
+2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: name of the source entity, as identified in step 1
+- target_entity: name of the target entity, as identified in step 1
+- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
+- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
+ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
+
+3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
+
+4. When finished, output {completion_delimiter}
+
+######################
+-Examples-
+######################
+Example 1:
+
+Entity_types: [person, technology, mission, organization, location]
+Text:
+while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
+
+Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
+
+The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
+
+It was a small transformation, barely perceptible, but one that Alex noted with an inward nod.
They had all been brought here by different paths +################ +Output: +("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter} +("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter} +("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter} +("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter} +("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}7){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}6){record_delimiter} +("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}8){record_delimiter} +("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}5){record_delimiter} +("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}9){completion_delimiter} +############################# +Example 2: + +Entity_types: [person, technology, mission, organization, location] +Text: +They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve. + +Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril. + +Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. 
A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
+#############
+Output:
+("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
+("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
+("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
+("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}7){record_delimiter}
+("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}9){completion_delimiter}
+#############################
+Example 3:
+
+Entity_types: [person, role, technology, organization, event, location, concept]
+Text:
+their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
+
+"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives 'talking to strangers' a whole new meaning."
+
+Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
+
+Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
+
+The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
+#############
+Output:
+("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
+("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
+("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
+("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
+("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
+("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
+("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}9){record_delimiter}
+("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}10){record_delimiter}
+("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}8){record_delimiter}
+("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}7){completion_delimiter}
+#############################
+-Real Data-
+######################
+Entity_types: {entity_types}
+Text: {input_text}
+######################
+Output:
\ No newline at end of file
diff --git a/func-app/prompts/summarize_descriptions.txt b/func-app/prompts/summarize_descriptions.txt
new file mode 100644
index 0000000000..54feaf32d4
--- /dev/null
+++ b/func-app/prompts/summarize_descriptions.txt
@@ -0,0 +1,13 @@
+
+You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
+You are given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
+Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
+If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
+Make sure it is written in third person, and include the entity names so we have the full context.
+ +####### +-Data- +Entities: {entity_name} +Description List: {description_list} +####### +Output: diff --git a/func-app/requirements.txt b/func-app/requirements.txt new file mode 100644 index 0000000000..8f1a0e42e2 --- /dev/null +++ b/func-app/requirements.txt @@ -0,0 +1,142 @@ +aenum==3.1.15 +aiofiles==24.1.0 +aiohappyeyeballs==2.4.0 +aiohttp==3.10.5 +aiolimiter==1.1.0 +aiosignal==1.3.1 +annotated-types==0.7.0 +anyio==4.4.0 +anytree==2.12.1 +asttokens==2.4.1 +async-timeout==4.0.3 +attrs==24.2.0 +autograd==1.7.0 +azure-common==1.1.28 +azure-core==1.31.0 +azure-cosmos==4.7.0 +azure-identity==1.17.1 +azure-kusto-data==4.5.1 +azure-search-documents==11.5.1 +azure-storage-blob==12.23.0 +beartype==0.18.5 +cachetools==5.5.0 +certifi==2024.8.30 +cffi==1.17.1 +charset-normalizer==3.3.2 +click==8.1.7 +cloudpickle==3.0.0 +colorama==0.4.6 +contourpy==1.3.0 +cramjam==2.8.3 +cryptography==43.0.1 +cycler==0.12.1 +dask-expr==1.1.14 +dask==2024.9.0 +dask[dataframe]==2024.9.0 +datashaper==0.0.49 +decorator==5.1.1 +deprecation==2.1.0 +devtools==0.12.2 +diskcache==5.6.3 +distro==1.9.0 +environs==11.0.0 +exceptiongroup==1.2.2 +executing==2.1.0 +fastparquet==2024.5.0 +fonttools==4.53.1 +frozenlist==1.4.1 +fsspec==2024.9.0 +gensim==4.3.3 +graspologic-native==1.2.1 +graspologic==3.4.1 +gremlinpython==3.7.2 +h11==0.14.0 +httpcore==1.0.5 +httpx==0.27.2 +hyppo==0.4.0 +idna==3.10 +ijson==3.3.0 +importlib-metadata==8.5.0 +isodate==0.6.1 +jiter==0.5.0 +joblib==1.4.2 +json-repair==0.25.3 +jsonschema-specifications==2023.12.1 +jsonschema==4.23.0 +kiwisolver==1.4.7 +lancedb==0.11.0 +linkify-it-py==2.0.3 +llvmlite==0.43.0 +locket==1.0.0 +markdown-it-py==3.0.0 +markdown-it-py[linkify,plugins]==3.0.0 +marshmallow==3.22.0 +matplotlib==3.9.2 +mdit-py-plugins==0.4.2 +mdurl==0.1.2 +msal-extensions==1.2.0 +msal==1.31.0 +multidict==6.1.0 +nest-asyncio==1.6.0 +networkx==3.3 +nltk==3.8.1 +numba==0.60.0 +numpy==1.26.4 +openai==1.46.0 +overrides==7.7.0 +packaging==24.1 +pandas==2.2.2 +partd==1.4.2 +patsy==0.5.6 +pillow==10.4.0 +portalocker==2.10.1 +pot==0.9.4 +psutil==6.0.0 +py==1.11.0 +pyaml-env==1.2.1 +pyarrow==15.0.2 +pycparser==2.22 +pydantic-core==2.23.4 +pydantic==2.9.2 +pygments==2.18.0 +pyjwt[crypto]==2.9.0 +pylance==0.15.0 +pynndescent==0.5.13 +pyparsing==3.1.4 +python-dateutil==2.9.0.post0 +python-dotenv==1.0.1 +pytz==2024.2 +pyyaml==6.0.2 +ratelimiter==1.2.0.post0 +referencing==0.35.1 +regex==2024.9.11 +requests==2.32.3 +retry==0.9.2 +rich==13.8.1 +rpds-py==0.20.0 +scikit-learn==1.5.2 +scipy==1.12.0 +seaborn==0.13.2 +six==1.16.0 +smart-open==7.0.4 +sniffio==1.3.1 +statsmodels==0.14.3 +swifter==1.4.0 +tenacity==8.5.0 +textual==0.74.0 +threadpoolctl==3.5.0 +tiktoken==0.7.0 +toolz==0.12.1 +tqdm==4.66.5 +typing-extensions==4.12.2 +tzdata==2024.1 +uc-micro-py==1.0.3 +umap-learn==0.5.6 +urllib3==2.2.3 +wrapt==1.16.0 +yarl==1.11.1 +zipp==3.20.2 +azure.functions +future +pandas +PyYAML \ No newline at end of file diff --git a/func-app/settings.yaml b/func-app/settings.yaml new file mode 100644 index 0000000000..3efcdd6bbd --- /dev/null +++ b/func-app/settings.yaml @@ -0,0 +1,165 @@ + +encoding_model: cl100k_base +skip_workflows: [] +llm: + api_key: ${GRAPHRAG_API_KEY} + type: openai_chat # or azure_openai_chat + model: gpt-4-turbo-preview + model_supports_json: true # recommended if this is available for your model. 
+ # max_tokens: 4000 + # request_timeout: 180.0 + # api_base: https://.openai.azure.com + # api_version: 2024-02-15-preview + # organization: + # deployment_name: + # tokens_per_minute: 150_000 # set a leaky bucket throttle + # requests_per_minute: 10_000 # set a leaky bucket throttle + # max_retries: 10 + # max_retry_wait: 10.0 + # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times + # concurrent_requests: 25 # the number of parallel inflight requests that may be made + # temperature: 0 # temperature for sampling + # top_p: 1 # top-p sampling + # n: 1 # Number of completions to generate + +parallelization: + stagger: 0.3 + # num_threads: 50 # the number of threads to use for parallel processing + +async_mode: threaded # or asyncio + +embeddings: + ## parallelization: override the global parallelization settings for embeddings + async_mode: threaded # or asyncio + llm: + api_key: ${GRAPHRAG_API_KEY} + type: openai_embedding # or azure_openai_embedding + model: text-embedding-3-small + # api_base: https://.openai.azure.com + # api_version: 2024-02-15-preview + # organization: + # deployment_name: + # tokens_per_minute: 150_000 # set a leaky bucket throttle + # requests_per_minute: 10_000 # set a leaky bucket throttle + # max_retries: 10 + # max_retry_wait: 10.0 + # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times + # concurrent_requests: 25 # the number of parallel inflight requests that may be made + # batch_size: 16 # the number of documents to send in a single request + # batch_max_tokens: 8191 # the maximum number of tokens to send in a single request + # target: required # or optional + + + +chunks: + size: 1200 + overlap: 100 + group_by_columns: [id] # by default, we don't allow chunks to cross documents + +input: + type: file # or blob + file_type: text # or csv + base_dir: "input" + file_encoding: utf-8 + file_pattern: ".*\\.txt$" + +cache: + type: file # or blob + base_dir: "cache" + # connection_string: + # container_name: + +storage: + type: file # or blob + base_dir: "output/${timestamp}/artifacts" + # connection_string: + # container_name: + +reporting: + type: file # or console, blob + base_dir: "output/${timestamp}/reports" + # connection_string: + # container_name: + +entity_extraction: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + prompt: "prompts/entity_extraction.txt" + entity_types: [organization,person,geo,event] + max_gleanings: 1 + +summarize_descriptions: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + prompt: "prompts/summarize_descriptions.txt" + max_length: 500 + +claim_extraction: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + # enabled: true + prompt: "prompts/claim_extraction.txt" + description: "Any claims or facts that could be relevant to information discovery." 
+ max_gleanings: 1 + +community_reports: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + prompt: "prompts/community_report.txt" + max_length: 2000 + max_input_length: 8000 + +cluster_graph: + max_cluster_size: 10 + +embed_graph: + enabled: false # if true, will generate node2vec embeddings for nodes + # num_walks: 10 + # walk_length: 40 + # window_size: 2 + # iterations: 3 + # random_seed: 597832 + +umap: + enabled: false # if true, will generate UMAP embeddings for nodes + +snapshots: + graphml: false + raw_entities: false + top_level_nodes: false + +local_search: + # text_unit_prop: 0.5 + # community_prop: 0.1 + # conversation_history_max_turns: 5 + # top_k_mapped_entities: 10 + # top_k_relationships: 10 + # llm_temperature: 0 # temperature for sampling + # llm_top_p: 1 # top-p sampling + # llm_n: 1 # Number of completions to generate + # max_tokens: 12000 + +global_search: + # llm_temperature: 0 # temperature for sampling + # llm_top_p: 1 # top-p sampling + # llm_n: 1 # Number of completions to generate + # max_tokens: 12000 + # data_max_tokens: 12000 + # map_max_tokens: 1000 + # reduce_max_tokens: 2000 + # concurrency: 32 + +query_context: + # Files: [] # list of files in context to run query + +graphdb: + account_name: '' + account_key: '' + username: '' + enabled: false + cosmos_url: '' + gremlin_url: '' diff --git a/func-app/settings/settings.yaml b/func-app/settings/settings.yaml new file mode 100644 index 0000000000..1e6a36da3e --- /dev/null +++ b/func-app/settings/settings.yaml @@ -0,0 +1,152 @@ + +encoding_model: cl100k_base +skip_workflows: [] +llm: + api_key: ${GRAPHRAG_API_KEY} + type: azure_openai_chat # openai_chat # or azure_openai_chat + model: gpt-4o + model_supports_json: true # recommended if this is available for your model. 
+ # max_tokens: 4000 + # request_timeout: 180.0 + api_base: https://spe-pds-ais-ai.openai.azure.com + api_version: 2024-04-01-preview + # organization: + deployment_name: spepdsaigpt + # tokens_per_minute: 150_000 # set a leaky bucket throttle + # requests_per_minute: 10_000 # set a leaky bucket throttle + # max_retries: 10 + # max_retry_wait: 10.0 + # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times + # concurrent_requests: 25 # the number of parallel inflight requests that may be made + # temperature: 0 # temperature for sampling + # top_p: 1 # top-p sampling + # n: 1 # Number of completions to generate + +parallelization: + stagger: 0.3 + # num_threads: 50 # the number of threads to use for parallel processing + +async_mode: threaded # or asyncio + +embeddings: + ## parallelization: override the global parallelization settings for embeddings + async_mode: threaded # or asyncio + llm: + api_key: ${GRAPHRAG_API_KEY} + type: azure_openai_embedding #openai_embedding # or azure_openai_embedding + model: text-embedding-ada-002 + api_base: https://spe-pds-ais-ai.openai.azure.com + api_version: 2024-04-01-preview + # organization: + deployment_name: spepdsaista + # tokens_per_minute: 150_000 # set a leaky bucket throttle + # requests_per_minute: 10_000 # set a leaky bucket throttle + # max_retries: 10 + # max_retry_wait: 10.0 + # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times + # concurrent_requests: 25 # the number of parallel inflight requests that may be made + # batch_size: 16 # the number of documents to send in a single request + # batch_max_tokens: 8191 # the maximum number of tokens to send in a single request + # target: required # or optional + +chunks: + size: 500 + overlap: 100 + group_by_columns: [id] # by default, we don't allow chunks to cross documents + +input: + type: file # or blob + file_type: text # or csv + base_dir: "input" + file_encoding: utf-8 + file_pattern: ".*\\.txt$" + +cache: + type: file # or blob + base_dir: "cache" + # connection_string: + # container_name: + +storage: + type: file # or blob + base_dir: "output/${timestamp}/artifacts" + # connection_string: + # container_name: + +reporting: + type: file # or console, blob + base_dir: "output/${timestamp}/reports" + # connection_string: + # container_name: + +entity_extraction: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + prompt: "prompts/entity_extraction.txt" + entity_types: [organization,person,geo,event] + max_gleanings: 1 + +summarize_descriptions: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + prompt: "prompts/summarize_descriptions.txt" + max_length: 500 + +claim_extraction: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + # enabled: true + prompt: "prompts/claim_extraction.txt" + description: "Any claims or facts that could be relevant to information discovery." 
+ max_gleanings: 1 + +community_reports: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + prompt: "prompts/community_report.txt" + max_length: 2000 + max_input_length: 8000 + +cluster_graph: + max_cluster_size: 10 + +embed_graph: + enabled: true # if true, will generate node2vec embeddings for nodes + # num_walks: 10 + # walk_length: 40 + # window_size: 2 + # iterations: 3 + # random_seed: 597832 + +umap: + enabled: false # if true, will generate UMAP embeddings for nodes + +snapshots: + graphml: false + raw_entities: false + top_level_nodes: false + +local_search: + # text_unit_prop: 0.5 + # community_prop: 0.1 + # conversation_history_max_turns: 5 + # top_k_mapped_entities: 10 + # top_k_relationships: 10 + # llm_temperature: 0 # temperature for sampling + # llm_top_p: 1 # top-p sampling + # llm_n: 1 # Number of completions to generate + # max_tokens: 12000 + +global_search: + # llm_temperature: 0 # temperature for sampling + # llm_top_p: 1 # top-p sampling + # llm_n: 1 # Number of completions to generate + # max_tokens: 12000 + # data_max_tokens: 12000 + # map_max_tokens: 1000 + # reduce_max_tokens: 2000 + # concurrency: 32 From 7a87843b14833b16ab352226aa716565e46abbd7 Mon Sep 17 00:00:00 2001 From: Prateek Jain Date: Sun, 22 Sep 2024 22:18:57 -0700 Subject: [PATCH 80/87] Added one req for windows local debug --- func-app/requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/func-app/requirements.txt b/func-app/requirements.txt index 8f1a0e42e2..4b51af40be 100644 --- a/func-app/requirements.txt +++ b/func-app/requirements.txt @@ -106,6 +106,7 @@ pyparsing==3.1.4 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 pytz==2024.2 +pywin32==306 ## This needs to be commented when deploying to the func app. pyyaml==6.0.2 ratelimiter==1.2.0.post0 referencing==0.35.1 @@ -138,5 +139,4 @@ yarl==1.11.1 zipp==3.20.2 azure.functions future -pandas -PyYAML \ No newline at end of file +pandas \ No newline at end of file From 3c1781a1a0c3ba7c01bd64044b0cf810b656ecfe Mon Sep 17 00:00:00 2001 From: Prateek Jain Date: Sun, 22 Sep 2024 22:42:55 -0700 Subject: [PATCH 81/87] Removing the redundant settings.yaml --- func-app/settings.yaml | 165 ----------------------------------------- 1 file changed, 165 deletions(-) delete mode 100644 func-app/settings.yaml diff --git a/func-app/settings.yaml b/func-app/settings.yaml deleted file mode 100644 index 3efcdd6bbd..0000000000 --- a/func-app/settings.yaml +++ /dev/null @@ -1,165 +0,0 @@ - -encoding_model: cl100k_base -skip_workflows: [] -llm: - api_key: ${GRAPHRAG_API_KEY} - type: openai_chat # or azure_openai_chat - model: gpt-4-turbo-preview - model_supports_json: true # recommended if this is available for your model. 
- # max_tokens: 4000 - # request_timeout: 180.0 - # api_base: https://.openai.azure.com - # api_version: 2024-02-15-preview - # organization: - # deployment_name: - # tokens_per_minute: 150_000 # set a leaky bucket throttle - # requests_per_minute: 10_000 # set a leaky bucket throttle - # max_retries: 10 - # max_retry_wait: 10.0 - # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times - # concurrent_requests: 25 # the number of parallel inflight requests that may be made - # temperature: 0 # temperature for sampling - # top_p: 1 # top-p sampling - # n: 1 # Number of completions to generate - -parallelization: - stagger: 0.3 - # num_threads: 50 # the number of threads to use for parallel processing - -async_mode: threaded # or asyncio - -embeddings: - ## parallelization: override the global parallelization settings for embeddings - async_mode: threaded # or asyncio - llm: - api_key: ${GRAPHRAG_API_KEY} - type: openai_embedding # or azure_openai_embedding - model: text-embedding-3-small - # api_base: https://.openai.azure.com - # api_version: 2024-02-15-preview - # organization: - # deployment_name: - # tokens_per_minute: 150_000 # set a leaky bucket throttle - # requests_per_minute: 10_000 # set a leaky bucket throttle - # max_retries: 10 - # max_retry_wait: 10.0 - # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times - # concurrent_requests: 25 # the number of parallel inflight requests that may be made - # batch_size: 16 # the number of documents to send in a single request - # batch_max_tokens: 8191 # the maximum number of tokens to send in a single request - # target: required # or optional - - - -chunks: - size: 1200 - overlap: 100 - group_by_columns: [id] # by default, we don't allow chunks to cross documents - -input: - type: file # or blob - file_type: text # or csv - base_dir: "input" - file_encoding: utf-8 - file_pattern: ".*\\.txt$" - -cache: - type: file # or blob - base_dir: "cache" - # connection_string: - # container_name: - -storage: - type: file # or blob - base_dir: "output/${timestamp}/artifacts" - # connection_string: - # container_name: - -reporting: - type: file # or console, blob - base_dir: "output/${timestamp}/reports" - # connection_string: - # container_name: - -entity_extraction: - ## llm: override the global llm settings for this task - ## parallelization: override the global parallelization settings for this task - ## async_mode: override the global async_mode settings for this task - prompt: "prompts/entity_extraction.txt" - entity_types: [organization,person,geo,event] - max_gleanings: 1 - -summarize_descriptions: - ## llm: override the global llm settings for this task - ## parallelization: override the global parallelization settings for this task - ## async_mode: override the global async_mode settings for this task - prompt: "prompts/summarize_descriptions.txt" - max_length: 500 - -claim_extraction: - ## llm: override the global llm settings for this task - ## parallelization: override the global parallelization settings for this task - ## async_mode: override the global async_mode settings for this task - # enabled: true - prompt: "prompts/claim_extraction.txt" - description: "Any claims or facts that could be relevant to information discovery." 
- max_gleanings: 1 - -community_reports: - ## llm: override the global llm settings for this task - ## parallelization: override the global parallelization settings for this task - ## async_mode: override the global async_mode settings for this task - prompt: "prompts/community_report.txt" - max_length: 2000 - max_input_length: 8000 - -cluster_graph: - max_cluster_size: 10 - -embed_graph: - enabled: false # if true, will generate node2vec embeddings for nodes - # num_walks: 10 - # walk_length: 40 - # window_size: 2 - # iterations: 3 - # random_seed: 597832 - -umap: - enabled: false # if true, will generate UMAP embeddings for nodes - -snapshots: - graphml: false - raw_entities: false - top_level_nodes: false - -local_search: - # text_unit_prop: 0.5 - # community_prop: 0.1 - # conversation_history_max_turns: 5 - # top_k_mapped_entities: 10 - # top_k_relationships: 10 - # llm_temperature: 0 # temperature for sampling - # llm_top_p: 1 # top-p sampling - # llm_n: 1 # Number of completions to generate - # max_tokens: 12000 - -global_search: - # llm_temperature: 0 # temperature for sampling - # llm_top_p: 1 # top-p sampling - # llm_n: 1 # Number of completions to generate - # max_tokens: 12000 - # data_max_tokens: 12000 - # map_max_tokens: 1000 - # reduce_max_tokens: 2000 - # concurrency: 32 - -query_context: - # Files: [] # list of files in context to run query - -graphdb: - account_name: '' - account_key: '' - username: '' - enabled: false - cosmos_url: '' - gremlin_url: '' From 05ee3b9283d888ad36ea5ca1bdea63062caa19c0 Mon Sep 17 00:00:00 2001 From: prateejain-linked Date: Sun, 22 Sep 2024 22:44:34 -0700 Subject: [PATCH 82/87] Added the func app compatible code (#50) * Added the func app compatible code * Added one req for windows local debug * Removing the redundant settings.yaml --------- Co-authored-by: Prateek Jain --- func-app/.gitignore | 49 ++ func-app/.vscode/launch.json | 13 + func-app/.vscode/settings.json | 8 + func-app/.vscode/tasks.json | 26 + func-app/common/graph_db_client.py | 158 ++++ func-app/function_app.py | 37 + func-app/graphrag/__init__.py | 4 + .../graphrag/common/blob_storage_client.py | 58 ++ func-app/graphrag/common/config/storage.py | 72 ++ func-app/graphrag/common/graph_db_client.py | 1 + func-app/graphrag/common/kusto_db_client.py | 1 + func-app/graphrag/common/progress/__init__.py | 8 + func-app/graphrag/common/progress/rich.py | 165 +++++ func-app/graphrag/common/progress/types.py | 128 ++++ func-app/graphrag/common/storage/__init__.py | 19 + .../common/storage/blob_pipeline_storage.py | 375 ++++++++++ .../common/storage/file_pipeline_storage.py | 166 +++++ .../graphrag/common/storage/load_storage.py | 40 + .../common/storage/memory_pipeline_storage.py | 79 ++ func-app/graphrag/common/storage/typing.py | 80 ++ .../graphrag/common/utils/common_utils.py | 11 + .../graphrag/common/utils/context_utils.py | 9 + func-app/graphrag/config/__init__.py | 128 ++++ .../graphrag/config/create_graphrag_config.py | 687 ++++++++++++++++++ func-app/graphrag/config/defaults.py | 106 +++ func-app/graphrag/config/enums.py | 127 ++++ .../graphrag/config/environment_reader.py | 155 ++++ func-app/graphrag/config/errors.py | 40 + .../graphrag/config/input_models/__init__.py | 50 ++ .../config/input_models/cache_config_input.py | 18 + .../input_models/chunking_config_input.py | 15 + .../claim_extraction_config_input.py | 19 + .../cluster_graph_config_input.py | 13 + .../community_reports_config_input.py | 17 + .../input_models/embed_graph_config_input.py | 18 + 
.../entity_extraction_config_input.py | 18 + .../global_search_config_input.py | 16 + .../input_models/graphrag_config_input.py | 49 ++ .../config/input_models/input_config_input.py | 27 + .../config/input_models/llm_config_input.py | 18 + .../input_models/llm_parameters_input.py | 31 + .../input_models/local_search_config_input.py | 18 + .../parallelization_parameters_input.py | 13 + .../query_context_config_input.py | 7 + .../input_models/reporting_config_input.py | 18 + .../input_models/snapshots_config_input.py | 14 + .../input_models/storage_config_input.py | 18 + .../summarize_descriptions_config_input.py | 16 + .../text_embedding_config_input.py | 23 + .../config/input_models/umap_config_input.py | 12 + func-app/graphrag/config/models/__init__.py | 52 ++ .../graphrag/config/models/cache_config.py | 29 + .../graphrag/config/models/chunking_config.py | 40 + .../config/models/claim_extraction_config.py | 57 ++ .../config/models/cluster_graph_config.py | 28 + .../config/models/community_reports_config.py | 48 ++ .../config/models/embed_graph_config.py | 48 ++ .../config/models/entity_extraction_config.py | 53 ++ .../config/models/global_search_config.py | 45 ++ .../config/models/graph_rag_config.py | 158 ++++ .../graphrag/config/models/graphdb_config.py | 37 + .../graphrag/config/models/input_config.py | 60 ++ func-app/graphrag/config/models/llm_config.py | 27 + .../graphrag/config/models/llm_parameters.py | 87 +++ .../config/models/local_search_config.py | 51 ++ .../models/parallelization_parameters.py | 21 + .../config/models/query_context_config.py | 16 + .../config/models/reporting_config.py | 30 + .../config/models/snapshots_config.py | 25 + .../graphrag/config/models/storage_config.py | 33 + .../models/summarize_descriptions_config.py | 43 ++ .../config/models/text_embedding_config.py | 46 ++ .../graphrag/config/models/umap_config.py | 17 + func-app/graphrag/config/read_dotenv.py | 25 + func-app/graphrag/index/__init__.py | 78 ++ func-app/graphrag/index/__main__.py | 125 ++++ func-app/graphrag/index/bootstrap.py | 28 + func-app/graphrag/index/cache/__init__.py | 18 + .../index/cache/json_pipeline_cache.py | 64 ++ func-app/graphrag/index/cache/load_cache.py | 51 ++ .../index/cache/memory_pipeline_cache.py | 83 +++ .../index/cache/noop_pipeline_cache.py | 65 ++ .../graphrag/index/cache/pipeline_cache.py | 67 ++ func-app/graphrag/index/cli.py | 356 +++++++++ func-app/graphrag/index/config/__init__.py | 69 ++ func-app/graphrag/index/config/cache.py | 82 +++ func-app/graphrag/index/config/input.py | 120 +++ func-app/graphrag/index/config/pipeline.py | 70 ++ func-app/graphrag/index/config/reporting.py | 77 ++ func-app/graphrag/index/config/workflow.py | 34 + func-app/graphrag/index/context.py | 42 ++ .../index/context_switch/contextSwitcher.py | 288 ++++++++ .../graphrag/index/create_pipeline_config.py | 595 +++++++++++++++ func-app/graphrag/index/emit/__init__.py | 21 + .../graphrag/index/emit/csv_table_emitter.py | 33 + func-app/graphrag/index/emit/factories.py | 46 ++ .../graphrag/index/emit/graph_db_emitter.py | 24 + .../graphrag/index/emit/json_table_emitter.py | 34 + .../index/emit/parquet_table_emitter.py | 54 ++ func-app/graphrag/index/emit/table_emitter.py | 15 + func-app/graphrag/index/emit/types.py | 15 + func-app/graphrag/index/errors.py | 25 + func-app/graphrag/index/graph/__init__.py | 4 + .../index/graph/embedding/__init__.py | 8 + .../index/graph/embedding/embedding.py | 41 ++ .../index/graph/extractors/__init__.py | 20 + .../index/graph/extractors/claims/__init__.py | 9 
+ .../extractors/claims/claim_extractor.py | 248 +++++++ .../index/graph/extractors/claims/prompts.py | 61 ++ .../extractors/community_reports/__init__.py | 35 + .../community_reports/build_mixed_context.py | 69 ++ .../community_reports_extractor.py | 107 +++ .../prep_community_report_context.py | 181 +++++ .../extractors/community_reports/prompts.py | 150 ++++ .../extractors/community_reports/schemas.py | 52 ++ .../community_reports/sort_context.py | 156 ++++ .../extractors/community_reports/utils.py | 53 ++ .../index/graph/extractors/graph/__init__.py | 18 + .../graph/extractors/graph/graph_extractor.py | 305 ++++++++ .../index/graph/extractors/graph/prompts.py | 129 ++++ .../graph/extractors/summarize/__init__.py | 12 + .../description_summary_extractor.py | 135 ++++ .../graph/extractors/summarize/prompts.py | 19 + .../graphrag/index/graph/utils/__init__.py | 9 + .../index/graph/utils/normalize_node_names.py | 14 + .../graphrag/index/graph/utils/stable_lcc.py | 60 ++ .../index/graph/visualization/__init__.py | 14 + .../visualization/compute_umap_positions.py | 144 ++++ .../index/graph/visualization/typing.py | 27 + func-app/graphrag/index/init_content.py | 176 +++++ func-app/graphrag/index/input/__init__.py | 8 + func-app/graphrag/index/input/csv.py | 138 ++++ func-app/graphrag/index/input/load_input.py | 84 +++ func-app/graphrag/index/input/text.py | 72 ++ func-app/graphrag/index/llm/__init__.py | 14 + func-app/graphrag/index/llm/load_llm.py | 313 ++++++++ func-app/graphrag/index/llm/types.py | 10 + .../graphrag/index/load_pipeline_config.py | 80 ++ func-app/graphrag/index/py.typed | 2 + func-app/graphrag/index/reporting/__init__.py | 18 + .../reporting/blob_workflow_callbacks.py | 108 +++ .../reporting/console_workflow_callbacks.py | 32 + .../reporting/file_workflow_callbacks.py | 67 ++ .../index/reporting/load_pipeline_reporter.py | 47 ++ .../reporting/progress_workflow_callbacks.py | 54 ++ func-app/graphrag/index/run.py | 471 ++++++++++++ .../graphrag/index/text_splitting/__init__.py | 34 + .../index/text_splitting/check_token_limit.py | 15 + .../index/text_splitting/text_splitting.py | 244 +++++++ func-app/graphrag/index/typing.py | 20 + func-app/graphrag/index/utils/__init__.py | 25 + func-app/graphrag/index/utils/dataframes.py | 61 ++ func-app/graphrag/index/utils/dicts.py | 18 + func-app/graphrag/index/utils/ds_util.py | 32 + func-app/graphrag/index/utils/hashing.py | 14 + func-app/graphrag/index/utils/is_null.py | 19 + func-app/graphrag/index/utils/load_graph.py | 11 + func-app/graphrag/index/utils/rate_limiter.py | 40 + func-app/graphrag/index/utils/string.py | 19 + func-app/graphrag/index/utils/tokens.py | 41 ++ .../graphrag/index/utils/topological_sort.py | 12 + func-app/graphrag/index/utils/uuid.py | 14 + func-app/graphrag/index/verbs/__init__.py | 50 ++ .../index/verbs/covariates/__init__.py | 8 + .../covariates/extract_covariates/__init__.py | 8 + .../extract_covariates/extract_covariates.py | 110 +++ .../extract_covariates/strategies/__init__.py | 4 + .../strategies/graph_intelligence/__init__.py | 8 + .../strategies/graph_intelligence/defaults.py | 21 + .../run_gi_extract_claims.py | 106 +++ .../graphrag/index/verbs/covariates/typing.py | 52 ++ .../graphrag/index/verbs/entities/__init__.py | 9 + .../verbs/entities/extraction/__init__.py | 8 + .../entities/extraction/entity_extract.py | 202 +++++ .../extraction/strategies/__init__.py | 4 + .../strategies/graph_intelligence/__init__.py | 8 + .../strategies/graph_intelligence/defaults.py | 25 + 
.../run_graph_intelligence.py | 142 ++++ .../entities/extraction/strategies/nltk.py | 61 ++ .../entities/extraction/strategies/typing.py | 44 ++ .../verbs/entities/summarize/__init__.py | 8 + .../summarize/description_summarize.py | 207 ++++++ .../entities/summarize/strategies/__init__.py | 8 + .../strategies/graph_intelligence/__init__.py | 8 + .../strategies/graph_intelligence/defaults.py | 17 + .../run_graph_intelligence.py | 70 ++ .../entities/summarize/strategies/typing.py | 34 + func-app/graphrag/index/verbs/genid.py | 66 ++ .../graphrag/index/verbs/graph/__init__.py | 36 + .../index/verbs/graph/clustering/__init__.py | 8 + .../verbs/graph/clustering/cluster_graph.py | 182 +++++ .../graph/clustering/strategies/__init__.py | 4 + .../graph/clustering/strategies/leiden.py | 69 ++ .../index/verbs/graph/clustering/typing.py | 6 + .../graph/compute_edge_combined_degree.py | 70 ++ func-app/graphrag/index/verbs/graph/create.py | 135 ++++ .../index/verbs/graph/embed/__init__.py | 8 + .../index/verbs/graph/embed/embed_graph.py | 98 +++ .../verbs/graph/embed/strategies/__init__.py | 4 + .../graph/embed/strategies/node_2_vec.py | 34 + .../index/verbs/graph/embed/typing.py | 12 + .../index/verbs/graph/layout/__init__.py | 8 + .../index/verbs/graph/layout/layout_graph.py | 139 ++++ .../verbs/graph/layout/methods/__init__.py | 4 + .../index/verbs/graph/layout/methods/umap.py | 82 +++ .../index/verbs/graph/layout/methods/zero.py | 63 ++ .../index/verbs/graph/merge/__init__.py | 8 + .../index/verbs/graph/merge/defaults.py | 21 + .../index/verbs/graph/merge/merge_graphs.py | 217 ++++++ .../index/verbs/graph/merge/typing.py | 49 ++ .../index/verbs/graph/report/__init__.py | 25 + .../graph/report/create_community_reports.py | 131 ++++ .../graph/report/prepare_community_reports.py | 187 +++++ .../prepare_community_reports_claims.py | 50 ++ .../report/prepare_community_reports_edges.py | 48 ++ .../report/prepare_community_reports_nodes.py | 46 ++ .../report/restore_community_hierarchy.py | 78 ++ .../verbs/graph/report/strategies/__init__.py | 4 + .../strategies/graph_intelligence/__init__.py | 8 + .../strategies/graph_intelligence/defaults.py | 26 + .../run_graph_intelligence.py | 99 +++ .../verbs/graph/report/strategies/typing.py | 52 ++ func-app/graphrag/index/verbs/graph/unpack.py | 107 +++ .../index/verbs/overrides/__init__.py | 10 + .../index/verbs/overrides/aggregate.py | 90 +++ .../graphrag/index/verbs/overrides/concat.py | 27 + .../graphrag/index/verbs/overrides/merge.py | 78 ++ func-app/graphrag/index/verbs/snapshot.py | 30 + .../graphrag/index/verbs/snapshot_rows.py | 86 +++ func-app/graphrag/index/verbs/spread_json.py | 55 ++ .../graphrag/index/verbs/text/__init__.py | 18 + .../index/verbs/text/chunk/__init__.py | 8 + .../verbs/text/chunk/strategies/__init__.py | 4 + .../verbs/text/chunk/strategies/sentence.py | 26 + .../verbs/text/chunk/strategies/tokens.py | 81 +++ .../verbs/text/chunk/strategies/typing.py | 17 + .../index/verbs/text/chunk/text_chunk.py | 162 +++++ .../graphrag/index/verbs/text/chunk/typing.py | 19 + .../index/verbs/text/embed/__init__.py | 8 + .../verbs/text/embed/strategies/__init__.py | 4 + .../index/verbs/text/embed/strategies/mock.py | 34 + .../verbs/text/embed/strategies/openai.py | 181 +++++ .../verbs/text/embed/strategies/typing.py | 29 + .../index/verbs/text/embed/text_embed.py | 269 +++++++ .../index/verbs/text/replace/__init__.py | 8 + .../index/verbs/text/replace/replace.py | 47 ++ .../index/verbs/text/replace/typing.py | 14 + 
func-app/graphrag/index/verbs/text/split.py | 54 ++ .../index/verbs/text/translate/__init__.py | 8 + .../text/translate/strategies/__init__.py | 9 + .../text/translate/strategies/defaults.py | 8 + .../verbs/text/translate/strategies/mock.py | 28 + .../verbs/text/translate/strategies/openai.py | 93 +++ .../verbs/text/translate/strategies/typing.py | 25 + .../verbs/text/translate/text_translate.py | 120 +++ func-app/graphrag/index/verbs/unzip.py | 25 + func-app/graphrag/index/verbs/zip.py | 51 ++ func-app/graphrag/index/workflows/__init__.py | 25 + .../index/workflows/default_workflows.py | 121 +++ func-app/graphrag/index/workflows/load.py | 171 +++++ func-app/graphrag/index/workflows/typing.py | 33 + .../graphrag/index/workflows/v1/__init__.py | 4 + .../workflows/v1/create_base_documents.py | 105 +++ .../workflows/v1/create_base_entity_graph.py | 91 +++ .../v1/create_base_extracted_entities.py | 95 +++ .../workflows/v1/create_base_text_units.py | 112 +++ .../workflows/v1/create_final_communities.py | 172 +++++ .../v1/create_final_community_reports.py | 133 ++++ .../workflows/v1/create_final_covariates.py | 90 +++ .../workflows/v1/create_final_documents.py | 41 ++ .../workflows/v1/create_final_entities.py | 133 ++++ .../index/workflows/v1/create_final_nodes.py | 116 +++ .../v1/create_final_relationships.py | 94 +++ .../workflows/v1/create_final_text_units.py | 161 ++++ .../v1/create_summarized_entities.py | 47 ++ .../v1/join_text_units_to_covariate_ids.py | 44 ++ .../v1/join_text_units_to_entity_ids.py | 50 ++ .../v1/join_text_units_to_relationship_ids.py | 55 ++ func-app/graphrag/llm/__init__.py | 91 +++ func-app/graphrag/llm/base/__init__.py | 10 + .../graphrag/llm/base/_create_cache_key.py | 43 ++ func-app/graphrag/llm/base/base_llm.py | 65 ++ func-app/graphrag/llm/base/caching_llm.py | 109 +++ .../graphrag/llm/base/rate_limiting_llm.py | 208 ++++++ func-app/graphrag/llm/errors.py | 12 + func-app/graphrag/llm/limiting/__init__.py | 18 + .../llm/limiting/composite_limiter.py | 26 + .../graphrag/llm/limiting/create_limiters.py | 29 + func-app/graphrag/llm/limiting/llm_limiter.py | 19 + .../graphrag/llm/limiting/noop_llm_limiter.py | 19 + .../graphrag/llm/limiting/tpm_rpm_limiter.py | 34 + func-app/graphrag/llm/mock/__init__.py | 12 + func-app/graphrag/llm/mock/mock_chat_llm.py | 52 ++ .../graphrag/llm/mock/mock_completion_llm.py | 37 + func-app/graphrag/llm/openai/__init__.py | 28 + func-app/graphrag/llm/openai/_prompts.py | 39 + .../llm/openai/create_openai_client.py | 65 ++ func-app/graphrag/llm/openai/factories.py | 140 ++++ .../graphrag/llm/openai/json_parsing_llm.py | 38 + .../graphrag/llm/openai/openai_chat_llm.py | 148 ++++ .../llm/openai/openai_completion_llm.py | 43 ++ .../llm/openai/openai_configuration.py | 288 ++++++++ .../llm/openai/openai_embeddings_llm.py | 40 + .../llm/openai/openai_history_tracking_llm.py | 42 ++ .../llm/openai/openai_token_replacing_llm.py | 37 + func-app/graphrag/llm/openai/types.py | 11 + func-app/graphrag/llm/openai/utils.py | 160 ++++ func-app/graphrag/llm/types/__init__.py | 46 ++ func-app/graphrag/llm/types/llm.py | 28 + func-app/graphrag/llm/types/llm_cache.py | 22 + func-app/graphrag/llm/types/llm_callbacks.py | 20 + func-app/graphrag/llm/types/llm_config.py | 35 + .../llm/types/llm_invocation_result.py | 35 + func-app/graphrag/llm/types/llm_io.py | 50 ++ func-app/graphrag/llm/types/llm_types.py | 16 + func-app/graphrag/model/__init__.py | 31 + func-app/graphrag/model/community.py | 54 ++ func-app/graphrag/model/community_report.py | 64 ++ 
func-app/graphrag/model/covariate.py | 61 ++ func-app/graphrag/model/document.py | 64 ++ func-app/graphrag/model/entity.py | 79 ++ func-app/graphrag/model/identified.py | 17 + func-app/graphrag/model/named.py | 16 + func-app/graphrag/model/relationship.py | 65 ++ func-app/graphrag/model/text_unit.py | 67 ++ func-app/graphrag/model/types.py | 8 + func-app/graphrag/prompt_tune/__init__.py | 4 + func-app/graphrag/prompt_tune/__main__.py | 148 ++++ func-app/graphrag/prompt_tune/cli.py | 272 +++++++ .../prompt_tune/generator/__init__.py | 30 + .../generator/community_report_rating.py | 35 + .../community_report_summarization.py | 48 ++ .../generator/community_reporter_role.py | 35 + .../prompt_tune/generator/defaults.py | 10 + .../graphrag/prompt_tune/generator/domain.py | 27 + .../generator/entity_extraction_prompt.py | 107 +++ .../generator/entity_relationship.py | 65 ++ .../generator/entity_summarization_prompt.py | 36 + .../prompt_tune/generator/entity_types.py | 45 ++ .../prompt_tune/generator/language.py | 27 + .../graphrag/prompt_tune/generator/persona.py | 27 + .../graphrag/prompt_tune/loader/__init__.py | 14 + .../graphrag/prompt_tune/loader/config.py | 43 ++ func-app/graphrag/prompt_tune/loader/input.py | 110 +++ .../graphrag/prompt_tune/prompt/__init__.py | 32 + .../prompt/community_report_rating.py | 132 ++++ .../prompt/community_reporter_role.py | 20 + .../graphrag/prompt_tune/prompt/domain.py | 12 + .../prompt_tune/prompt/entity_relationship.py | 355 +++++++++ .../prompt_tune/prompt/entity_types.py | 89 +++ .../graphrag/prompt_tune/prompt/language.py | 12 + .../graphrag/prompt_tune/prompt/persona.py | 13 + .../graphrag/prompt_tune/template/__init__.py | 24 + .../community_report_summarization.py | 95 +++ .../prompt_tune/template/entity_extraction.py | 141 ++++ .../template/entity_summarization.py | 22 + func-app/graphrag/query/__init__.py | 4 + func-app/graphrag/query/__main__.py | 133 ++++ func-app/graphrag/query/cli.py | 472 ++++++++++++ .../query/context_builder/__init__.py | 4 + .../query/context_builder/builders.py | 35 + .../context_builder/community_context.py | 253 +++++++ .../context_builder/conversation_history.py | 212 ++++++ .../context_builder/entity_extraction.py | 187 +++++ .../query/context_builder/local_context.py | 360 +++++++++ .../query/context_builder/source_context.py | 110 +++ func-app/graphrag/query/factories.py | 211 ++++++ func-app/graphrag/query/indexer_adapters.py | 159 ++++ func-app/graphrag/query/input/__init__.py | 4 + .../graphrag/query/input/loaders/__init__.py | 4 + func-app/graphrag/query/input/loaders/dfs.py | 340 +++++++++ .../graphrag/query/input/loaders/utils.py | 245 +++++++ .../query/input/retrieval/__init__.py | 4 + .../input/retrieval/community_reports.py | 74 ++ .../query/input/retrieval/covariates.py | 52 ++ .../query/input/retrieval/entities.py | 93 +++ .../query/input/retrieval/relationships.py | 217 ++++++ .../query/input/retrieval/text_units.py | 52 ++ func-app/graphrag/query/llm/__init__.py | 4 + func-app/graphrag/query/llm/base.py | 54 ++ func-app/graphrag/query/llm/oai/__init__.py | 21 + func-app/graphrag/query/llm/oai/base.py | 187 +++++ .../graphrag/query/llm/oai/chat_openai.py | 206 ++++++ func-app/graphrag/query/llm/oai/embedding.py | 182 +++++ func-app/graphrag/query/llm/oai/openai.py | 187 +++++ func-app/graphrag/query/llm/oai/typing.py | 23 + func-app/graphrag/query/llm/text_utils.py | 42 ++ func-app/graphrag/query/progress.py | 43 ++ .../graphrag/query/question_gen/__init__.py | 4 + 
func-app/graphrag/query/question_gen/base.py | 65 ++ .../graphrag/query/question_gen/local_gen.py | 194 +++++ .../query/question_gen/system_prompt.py | 28 + .../query/structured_search/__init__.py | 4 + .../graphrag/query/structured_search/base.py | 69 ++ .../global_search/__init__.py | 4 + .../global_search/callbacks.py | 24 + .../global_search/community_context.py | 99 +++ .../global_search/map_system_prompt.py | 82 +++ .../global_search/reduce_system_prompt.py | 88 +++ .../structured_search/global_search/search.py | 359 +++++++++ .../local_search/__init__.py | 4 + .../local_search/mixed_context.py | 533 ++++++++++++++ .../structured_search/local_search/search.py | 199 +++++ .../local_search/system_prompt.py | 69 ++ func-app/graphrag/vector_stores/__init__.py | 19 + .../graphrag/vector_stores/azure_ai_search.py | 225 ++++++ func-app/graphrag/vector_stores/base.py | 130 ++++ func-app/graphrag/vector_stores/kusto.py | 308 ++++++++ func-app/graphrag/vector_stores/lancedb.py | 152 ++++ func-app/graphrag/vector_stores/typing.py | 48 ++ func-app/host.json | 15 + func-app/prompts/claim_extraction.txt | 52 ++ func-app/prompts/community_report.txt | 146 ++++ func-app/prompts/entity_extraction.txt | 99 +++ func-app/prompts/summarize_descriptions.txt | 13 + func-app/requirements.txt | 142 ++++ func-app/settings/settings.yaml | 152 ++++ 417 files changed, 30442 insertions(+) create mode 100644 func-app/.gitignore create mode 100644 func-app/.vscode/launch.json create mode 100644 func-app/.vscode/settings.json create mode 100644 func-app/.vscode/tasks.json create mode 100644 func-app/common/graph_db_client.py create mode 100644 func-app/function_app.py create mode 100644 func-app/graphrag/__init__.py create mode 100644 func-app/graphrag/common/blob_storage_client.py create mode 100644 func-app/graphrag/common/config/storage.py create mode 100644 func-app/graphrag/common/graph_db_client.py create mode 100644 func-app/graphrag/common/kusto_db_client.py create mode 100644 func-app/graphrag/common/progress/__init__.py create mode 100644 func-app/graphrag/common/progress/rich.py create mode 100644 func-app/graphrag/common/progress/types.py create mode 100644 func-app/graphrag/common/storage/__init__.py create mode 100644 func-app/graphrag/common/storage/blob_pipeline_storage.py create mode 100644 func-app/graphrag/common/storage/file_pipeline_storage.py create mode 100644 func-app/graphrag/common/storage/load_storage.py create mode 100644 func-app/graphrag/common/storage/memory_pipeline_storage.py create mode 100644 func-app/graphrag/common/storage/typing.py create mode 100644 func-app/graphrag/common/utils/common_utils.py create mode 100644 func-app/graphrag/common/utils/context_utils.py create mode 100644 func-app/graphrag/config/__init__.py create mode 100644 func-app/graphrag/config/create_graphrag_config.py create mode 100644 func-app/graphrag/config/defaults.py create mode 100644 func-app/graphrag/config/enums.py create mode 100644 func-app/graphrag/config/environment_reader.py create mode 100644 func-app/graphrag/config/errors.py create mode 100644 func-app/graphrag/config/input_models/__init__.py create mode 100644 func-app/graphrag/config/input_models/cache_config_input.py create mode 100644 func-app/graphrag/config/input_models/chunking_config_input.py create mode 100644 func-app/graphrag/config/input_models/claim_extraction_config_input.py create mode 100644 func-app/graphrag/config/input_models/cluster_graph_config_input.py create mode 100644 
func-app/graphrag/config/input_models/community_reports_config_input.py create mode 100644 func-app/graphrag/config/input_models/embed_graph_config_input.py create mode 100644 func-app/graphrag/config/input_models/entity_extraction_config_input.py create mode 100644 func-app/graphrag/config/input_models/global_search_config_input.py create mode 100644 func-app/graphrag/config/input_models/graphrag_config_input.py create mode 100644 func-app/graphrag/config/input_models/input_config_input.py create mode 100644 func-app/graphrag/config/input_models/llm_config_input.py create mode 100644 func-app/graphrag/config/input_models/llm_parameters_input.py create mode 100644 func-app/graphrag/config/input_models/local_search_config_input.py create mode 100644 func-app/graphrag/config/input_models/parallelization_parameters_input.py create mode 100644 func-app/graphrag/config/input_models/query_context_config_input.py create mode 100644 func-app/graphrag/config/input_models/reporting_config_input.py create mode 100644 func-app/graphrag/config/input_models/snapshots_config_input.py create mode 100644 func-app/graphrag/config/input_models/storage_config_input.py create mode 100644 func-app/graphrag/config/input_models/summarize_descriptions_config_input.py create mode 100644 func-app/graphrag/config/input_models/text_embedding_config_input.py create mode 100644 func-app/graphrag/config/input_models/umap_config_input.py create mode 100644 func-app/graphrag/config/models/__init__.py create mode 100644 func-app/graphrag/config/models/cache_config.py create mode 100644 func-app/graphrag/config/models/chunking_config.py create mode 100644 func-app/graphrag/config/models/claim_extraction_config.py create mode 100644 func-app/graphrag/config/models/cluster_graph_config.py create mode 100644 func-app/graphrag/config/models/community_reports_config.py create mode 100644 func-app/graphrag/config/models/embed_graph_config.py create mode 100644 func-app/graphrag/config/models/entity_extraction_config.py create mode 100644 func-app/graphrag/config/models/global_search_config.py create mode 100644 func-app/graphrag/config/models/graph_rag_config.py create mode 100644 func-app/graphrag/config/models/graphdb_config.py create mode 100644 func-app/graphrag/config/models/input_config.py create mode 100644 func-app/graphrag/config/models/llm_config.py create mode 100644 func-app/graphrag/config/models/llm_parameters.py create mode 100644 func-app/graphrag/config/models/local_search_config.py create mode 100644 func-app/graphrag/config/models/parallelization_parameters.py create mode 100644 func-app/graphrag/config/models/query_context_config.py create mode 100644 func-app/graphrag/config/models/reporting_config.py create mode 100644 func-app/graphrag/config/models/snapshots_config.py create mode 100644 func-app/graphrag/config/models/storage_config.py create mode 100644 func-app/graphrag/config/models/summarize_descriptions_config.py create mode 100644 func-app/graphrag/config/models/text_embedding_config.py create mode 100644 func-app/graphrag/config/models/umap_config.py create mode 100644 func-app/graphrag/config/read_dotenv.py create mode 100644 func-app/graphrag/index/__init__.py create mode 100644 func-app/graphrag/index/__main__.py create mode 100644 func-app/graphrag/index/bootstrap.py create mode 100644 func-app/graphrag/index/cache/__init__.py create mode 100644 func-app/graphrag/index/cache/json_pipeline_cache.py create mode 100644 func-app/graphrag/index/cache/load_cache.py create mode 100644 
func-app/graphrag/index/cache/memory_pipeline_cache.py create mode 100644 func-app/graphrag/index/cache/noop_pipeline_cache.py create mode 100644 func-app/graphrag/index/cache/pipeline_cache.py create mode 100644 func-app/graphrag/index/cli.py create mode 100644 func-app/graphrag/index/config/__init__.py create mode 100644 func-app/graphrag/index/config/cache.py create mode 100644 func-app/graphrag/index/config/input.py create mode 100644 func-app/graphrag/index/config/pipeline.py create mode 100644 func-app/graphrag/index/config/reporting.py create mode 100644 func-app/graphrag/index/config/workflow.py create mode 100644 func-app/graphrag/index/context.py create mode 100644 func-app/graphrag/index/context_switch/contextSwitcher.py create mode 100644 func-app/graphrag/index/create_pipeline_config.py create mode 100644 func-app/graphrag/index/emit/__init__.py create mode 100644 func-app/graphrag/index/emit/csv_table_emitter.py create mode 100644 func-app/graphrag/index/emit/factories.py create mode 100644 func-app/graphrag/index/emit/graph_db_emitter.py create mode 100644 func-app/graphrag/index/emit/json_table_emitter.py create mode 100644 func-app/graphrag/index/emit/parquet_table_emitter.py create mode 100644 func-app/graphrag/index/emit/table_emitter.py create mode 100644 func-app/graphrag/index/emit/types.py create mode 100644 func-app/graphrag/index/errors.py create mode 100644 func-app/graphrag/index/graph/__init__.py create mode 100644 func-app/graphrag/index/graph/embedding/__init__.py create mode 100644 func-app/graphrag/index/graph/embedding/embedding.py create mode 100644 func-app/graphrag/index/graph/extractors/__init__.py create mode 100644 func-app/graphrag/index/graph/extractors/claims/__init__.py create mode 100644 func-app/graphrag/index/graph/extractors/claims/claim_extractor.py create mode 100644 func-app/graphrag/index/graph/extractors/claims/prompts.py create mode 100644 func-app/graphrag/index/graph/extractors/community_reports/__init__.py create mode 100644 func-app/graphrag/index/graph/extractors/community_reports/build_mixed_context.py create mode 100644 func-app/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py create mode 100644 func-app/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py create mode 100644 func-app/graphrag/index/graph/extractors/community_reports/prompts.py create mode 100644 func-app/graphrag/index/graph/extractors/community_reports/schemas.py create mode 100644 func-app/graphrag/index/graph/extractors/community_reports/sort_context.py create mode 100644 func-app/graphrag/index/graph/extractors/community_reports/utils.py create mode 100644 func-app/graphrag/index/graph/extractors/graph/__init__.py create mode 100644 func-app/graphrag/index/graph/extractors/graph/graph_extractor.py create mode 100644 func-app/graphrag/index/graph/extractors/graph/prompts.py create mode 100644 func-app/graphrag/index/graph/extractors/summarize/__init__.py create mode 100644 func-app/graphrag/index/graph/extractors/summarize/description_summary_extractor.py create mode 100644 func-app/graphrag/index/graph/extractors/summarize/prompts.py create mode 100644 func-app/graphrag/index/graph/utils/__init__.py create mode 100644 func-app/graphrag/index/graph/utils/normalize_node_names.py create mode 100644 func-app/graphrag/index/graph/utils/stable_lcc.py create mode 100644 func-app/graphrag/index/graph/visualization/__init__.py create mode 100644 
func-app/graphrag/index/graph/visualization/compute_umap_positions.py create mode 100644 func-app/graphrag/index/graph/visualization/typing.py create mode 100644 func-app/graphrag/index/init_content.py create mode 100644 func-app/graphrag/index/input/__init__.py create mode 100644 func-app/graphrag/index/input/csv.py create mode 100644 func-app/graphrag/index/input/load_input.py create mode 100644 func-app/graphrag/index/input/text.py create mode 100644 func-app/graphrag/index/llm/__init__.py create mode 100644 func-app/graphrag/index/llm/load_llm.py create mode 100644 func-app/graphrag/index/llm/types.py create mode 100644 func-app/graphrag/index/load_pipeline_config.py create mode 100644 func-app/graphrag/index/py.typed create mode 100644 func-app/graphrag/index/reporting/__init__.py create mode 100644 func-app/graphrag/index/reporting/blob_workflow_callbacks.py create mode 100644 func-app/graphrag/index/reporting/console_workflow_callbacks.py create mode 100644 func-app/graphrag/index/reporting/file_workflow_callbacks.py create mode 100644 func-app/graphrag/index/reporting/load_pipeline_reporter.py create mode 100644 func-app/graphrag/index/reporting/progress_workflow_callbacks.py create mode 100644 func-app/graphrag/index/run.py create mode 100644 func-app/graphrag/index/text_splitting/__init__.py create mode 100644 func-app/graphrag/index/text_splitting/check_token_limit.py create mode 100644 func-app/graphrag/index/text_splitting/text_splitting.py create mode 100644 func-app/graphrag/index/typing.py create mode 100644 func-app/graphrag/index/utils/__init__.py create mode 100644 func-app/graphrag/index/utils/dataframes.py create mode 100644 func-app/graphrag/index/utils/dicts.py create mode 100644 func-app/graphrag/index/utils/ds_util.py create mode 100644 func-app/graphrag/index/utils/hashing.py create mode 100644 func-app/graphrag/index/utils/is_null.py create mode 100644 func-app/graphrag/index/utils/load_graph.py create mode 100644 func-app/graphrag/index/utils/rate_limiter.py create mode 100644 func-app/graphrag/index/utils/string.py create mode 100644 func-app/graphrag/index/utils/tokens.py create mode 100644 func-app/graphrag/index/utils/topological_sort.py create mode 100644 func-app/graphrag/index/utils/uuid.py create mode 100644 func-app/graphrag/index/verbs/__init__.py create mode 100644 func-app/graphrag/index/verbs/covariates/__init__.py create mode 100644 func-app/graphrag/index/verbs/covariates/extract_covariates/__init__.py create mode 100644 func-app/graphrag/index/verbs/covariates/extract_covariates/extract_covariates.py create mode 100644 func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/__init__.py create mode 100644 func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/__init__.py create mode 100644 func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/defaults.py create mode 100644 func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/run_gi_extract_claims.py create mode 100644 func-app/graphrag/index/verbs/covariates/typing.py create mode 100644 func-app/graphrag/index/verbs/entities/__init__.py create mode 100644 func-app/graphrag/index/verbs/entities/extraction/__init__.py create mode 100644 func-app/graphrag/index/verbs/entities/extraction/entity_extract.py create mode 100644 func-app/graphrag/index/verbs/entities/extraction/strategies/__init__.py create mode 100644 
func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/__init__.py create mode 100644 func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/defaults.py create mode 100644 func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/run_graph_intelligence.py create mode 100644 func-app/graphrag/index/verbs/entities/extraction/strategies/nltk.py create mode 100644 func-app/graphrag/index/verbs/entities/extraction/strategies/typing.py create mode 100644 func-app/graphrag/index/verbs/entities/summarize/__init__.py create mode 100644 func-app/graphrag/index/verbs/entities/summarize/description_summarize.py create mode 100644 func-app/graphrag/index/verbs/entities/summarize/strategies/__init__.py create mode 100644 func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/__init__.py create mode 100644 func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/defaults.py create mode 100644 func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/run_graph_intelligence.py create mode 100644 func-app/graphrag/index/verbs/entities/summarize/strategies/typing.py create mode 100644 func-app/graphrag/index/verbs/genid.py create mode 100644 func-app/graphrag/index/verbs/graph/__init__.py create mode 100644 func-app/graphrag/index/verbs/graph/clustering/__init__.py create mode 100644 func-app/graphrag/index/verbs/graph/clustering/cluster_graph.py create mode 100644 func-app/graphrag/index/verbs/graph/clustering/strategies/__init__.py create mode 100644 func-app/graphrag/index/verbs/graph/clustering/strategies/leiden.py create mode 100644 func-app/graphrag/index/verbs/graph/clustering/typing.py create mode 100644 func-app/graphrag/index/verbs/graph/compute_edge_combined_degree.py create mode 100644 func-app/graphrag/index/verbs/graph/create.py create mode 100644 func-app/graphrag/index/verbs/graph/embed/__init__.py create mode 100644 func-app/graphrag/index/verbs/graph/embed/embed_graph.py create mode 100644 func-app/graphrag/index/verbs/graph/embed/strategies/__init__.py create mode 100644 func-app/graphrag/index/verbs/graph/embed/strategies/node_2_vec.py create mode 100644 func-app/graphrag/index/verbs/graph/embed/typing.py create mode 100644 func-app/graphrag/index/verbs/graph/layout/__init__.py create mode 100644 func-app/graphrag/index/verbs/graph/layout/layout_graph.py create mode 100644 func-app/graphrag/index/verbs/graph/layout/methods/__init__.py create mode 100644 func-app/graphrag/index/verbs/graph/layout/methods/umap.py create mode 100644 func-app/graphrag/index/verbs/graph/layout/methods/zero.py create mode 100644 func-app/graphrag/index/verbs/graph/merge/__init__.py create mode 100644 func-app/graphrag/index/verbs/graph/merge/defaults.py create mode 100644 func-app/graphrag/index/verbs/graph/merge/merge_graphs.py create mode 100644 func-app/graphrag/index/verbs/graph/merge/typing.py create mode 100644 func-app/graphrag/index/verbs/graph/report/__init__.py create mode 100644 func-app/graphrag/index/verbs/graph/report/create_community_reports.py create mode 100644 func-app/graphrag/index/verbs/graph/report/prepare_community_reports.py create mode 100644 func-app/graphrag/index/verbs/graph/report/prepare_community_reports_claims.py create mode 100644 func-app/graphrag/index/verbs/graph/report/prepare_community_reports_edges.py create mode 100644 func-app/graphrag/index/verbs/graph/report/prepare_community_reports_nodes.py create mode 100644 
func-app/graphrag/index/verbs/graph/report/restore_community_hierarchy.py create mode 100644 func-app/graphrag/index/verbs/graph/report/strategies/__init__.py create mode 100644 func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/__init__.py create mode 100644 func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/defaults.py create mode 100644 func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/run_graph_intelligence.py create mode 100644 func-app/graphrag/index/verbs/graph/report/strategies/typing.py create mode 100644 func-app/graphrag/index/verbs/graph/unpack.py create mode 100644 func-app/graphrag/index/verbs/overrides/__init__.py create mode 100644 func-app/graphrag/index/verbs/overrides/aggregate.py create mode 100644 func-app/graphrag/index/verbs/overrides/concat.py create mode 100644 func-app/graphrag/index/verbs/overrides/merge.py create mode 100644 func-app/graphrag/index/verbs/snapshot.py create mode 100644 func-app/graphrag/index/verbs/snapshot_rows.py create mode 100644 func-app/graphrag/index/verbs/spread_json.py create mode 100644 func-app/graphrag/index/verbs/text/__init__.py create mode 100644 func-app/graphrag/index/verbs/text/chunk/__init__.py create mode 100644 func-app/graphrag/index/verbs/text/chunk/strategies/__init__.py create mode 100644 func-app/graphrag/index/verbs/text/chunk/strategies/sentence.py create mode 100644 func-app/graphrag/index/verbs/text/chunk/strategies/tokens.py create mode 100644 func-app/graphrag/index/verbs/text/chunk/strategies/typing.py create mode 100644 func-app/graphrag/index/verbs/text/chunk/text_chunk.py create mode 100644 func-app/graphrag/index/verbs/text/chunk/typing.py create mode 100644 func-app/graphrag/index/verbs/text/embed/__init__.py create mode 100644 func-app/graphrag/index/verbs/text/embed/strategies/__init__.py create mode 100644 func-app/graphrag/index/verbs/text/embed/strategies/mock.py create mode 100644 func-app/graphrag/index/verbs/text/embed/strategies/openai.py create mode 100644 func-app/graphrag/index/verbs/text/embed/strategies/typing.py create mode 100644 func-app/graphrag/index/verbs/text/embed/text_embed.py create mode 100644 func-app/graphrag/index/verbs/text/replace/__init__.py create mode 100644 func-app/graphrag/index/verbs/text/replace/replace.py create mode 100644 func-app/graphrag/index/verbs/text/replace/typing.py create mode 100644 func-app/graphrag/index/verbs/text/split.py create mode 100644 func-app/graphrag/index/verbs/text/translate/__init__.py create mode 100644 func-app/graphrag/index/verbs/text/translate/strategies/__init__.py create mode 100644 func-app/graphrag/index/verbs/text/translate/strategies/defaults.py create mode 100644 func-app/graphrag/index/verbs/text/translate/strategies/mock.py create mode 100644 func-app/graphrag/index/verbs/text/translate/strategies/openai.py create mode 100644 func-app/graphrag/index/verbs/text/translate/strategies/typing.py create mode 100644 func-app/graphrag/index/verbs/text/translate/text_translate.py create mode 100644 func-app/graphrag/index/verbs/unzip.py create mode 100644 func-app/graphrag/index/verbs/zip.py create mode 100644 func-app/graphrag/index/workflows/__init__.py create mode 100644 func-app/graphrag/index/workflows/default_workflows.py create mode 100644 func-app/graphrag/index/workflows/load.py create mode 100644 func-app/graphrag/index/workflows/typing.py create mode 100644 func-app/graphrag/index/workflows/v1/__init__.py create mode 100644 
func-app/graphrag/index/workflows/v1/create_base_documents.py create mode 100644 func-app/graphrag/index/workflows/v1/create_base_entity_graph.py create mode 100644 func-app/graphrag/index/workflows/v1/create_base_extracted_entities.py create mode 100644 func-app/graphrag/index/workflows/v1/create_base_text_units.py create mode 100644 func-app/graphrag/index/workflows/v1/create_final_communities.py create mode 100644 func-app/graphrag/index/workflows/v1/create_final_community_reports.py create mode 100644 func-app/graphrag/index/workflows/v1/create_final_covariates.py create mode 100644 func-app/graphrag/index/workflows/v1/create_final_documents.py create mode 100644 func-app/graphrag/index/workflows/v1/create_final_entities.py create mode 100644 func-app/graphrag/index/workflows/v1/create_final_nodes.py create mode 100644 func-app/graphrag/index/workflows/v1/create_final_relationships.py create mode 100644 func-app/graphrag/index/workflows/v1/create_final_text_units.py create mode 100644 func-app/graphrag/index/workflows/v1/create_summarized_entities.py create mode 100644 func-app/graphrag/index/workflows/v1/join_text_units_to_covariate_ids.py create mode 100644 func-app/graphrag/index/workflows/v1/join_text_units_to_entity_ids.py create mode 100644 func-app/graphrag/index/workflows/v1/join_text_units_to_relationship_ids.py create mode 100644 func-app/graphrag/llm/__init__.py create mode 100644 func-app/graphrag/llm/base/__init__.py create mode 100644 func-app/graphrag/llm/base/_create_cache_key.py create mode 100644 func-app/graphrag/llm/base/base_llm.py create mode 100644 func-app/graphrag/llm/base/caching_llm.py create mode 100644 func-app/graphrag/llm/base/rate_limiting_llm.py create mode 100644 func-app/graphrag/llm/errors.py create mode 100644 func-app/graphrag/llm/limiting/__init__.py create mode 100644 func-app/graphrag/llm/limiting/composite_limiter.py create mode 100644 func-app/graphrag/llm/limiting/create_limiters.py create mode 100644 func-app/graphrag/llm/limiting/llm_limiter.py create mode 100644 func-app/graphrag/llm/limiting/noop_llm_limiter.py create mode 100644 func-app/graphrag/llm/limiting/tpm_rpm_limiter.py create mode 100644 func-app/graphrag/llm/mock/__init__.py create mode 100644 func-app/graphrag/llm/mock/mock_chat_llm.py create mode 100644 func-app/graphrag/llm/mock/mock_completion_llm.py create mode 100644 func-app/graphrag/llm/openai/__init__.py create mode 100644 func-app/graphrag/llm/openai/_prompts.py create mode 100644 func-app/graphrag/llm/openai/create_openai_client.py create mode 100644 func-app/graphrag/llm/openai/factories.py create mode 100644 func-app/graphrag/llm/openai/json_parsing_llm.py create mode 100644 func-app/graphrag/llm/openai/openai_chat_llm.py create mode 100644 func-app/graphrag/llm/openai/openai_completion_llm.py create mode 100644 func-app/graphrag/llm/openai/openai_configuration.py create mode 100644 func-app/graphrag/llm/openai/openai_embeddings_llm.py create mode 100644 func-app/graphrag/llm/openai/openai_history_tracking_llm.py create mode 100644 func-app/graphrag/llm/openai/openai_token_replacing_llm.py create mode 100644 func-app/graphrag/llm/openai/types.py create mode 100644 func-app/graphrag/llm/openai/utils.py create mode 100644 func-app/graphrag/llm/types/__init__.py create mode 100644 func-app/graphrag/llm/types/llm.py create mode 100644 func-app/graphrag/llm/types/llm_cache.py create mode 100644 func-app/graphrag/llm/types/llm_callbacks.py create mode 100644 func-app/graphrag/llm/types/llm_config.py create mode 100644 
func-app/graphrag/llm/types/llm_invocation_result.py create mode 100644 func-app/graphrag/llm/types/llm_io.py create mode 100644 func-app/graphrag/llm/types/llm_types.py create mode 100644 func-app/graphrag/model/__init__.py create mode 100644 func-app/graphrag/model/community.py create mode 100644 func-app/graphrag/model/community_report.py create mode 100644 func-app/graphrag/model/covariate.py create mode 100644 func-app/graphrag/model/document.py create mode 100644 func-app/graphrag/model/entity.py create mode 100644 func-app/graphrag/model/identified.py create mode 100644 func-app/graphrag/model/named.py create mode 100644 func-app/graphrag/model/relationship.py create mode 100644 func-app/graphrag/model/text_unit.py create mode 100644 func-app/graphrag/model/types.py create mode 100644 func-app/graphrag/prompt_tune/__init__.py create mode 100644 func-app/graphrag/prompt_tune/__main__.py create mode 100644 func-app/graphrag/prompt_tune/cli.py create mode 100644 func-app/graphrag/prompt_tune/generator/__init__.py create mode 100644 func-app/graphrag/prompt_tune/generator/community_report_rating.py create mode 100644 func-app/graphrag/prompt_tune/generator/community_report_summarization.py create mode 100644 func-app/graphrag/prompt_tune/generator/community_reporter_role.py create mode 100644 func-app/graphrag/prompt_tune/generator/defaults.py create mode 100644 func-app/graphrag/prompt_tune/generator/domain.py create mode 100644 func-app/graphrag/prompt_tune/generator/entity_extraction_prompt.py create mode 100644 func-app/graphrag/prompt_tune/generator/entity_relationship.py create mode 100644 func-app/graphrag/prompt_tune/generator/entity_summarization_prompt.py create mode 100644 func-app/graphrag/prompt_tune/generator/entity_types.py create mode 100644 func-app/graphrag/prompt_tune/generator/language.py create mode 100644 func-app/graphrag/prompt_tune/generator/persona.py create mode 100644 func-app/graphrag/prompt_tune/loader/__init__.py create mode 100644 func-app/graphrag/prompt_tune/loader/config.py create mode 100644 func-app/graphrag/prompt_tune/loader/input.py create mode 100644 func-app/graphrag/prompt_tune/prompt/__init__.py create mode 100644 func-app/graphrag/prompt_tune/prompt/community_report_rating.py create mode 100644 func-app/graphrag/prompt_tune/prompt/community_reporter_role.py create mode 100644 func-app/graphrag/prompt_tune/prompt/domain.py create mode 100644 func-app/graphrag/prompt_tune/prompt/entity_relationship.py create mode 100644 func-app/graphrag/prompt_tune/prompt/entity_types.py create mode 100644 func-app/graphrag/prompt_tune/prompt/language.py create mode 100644 func-app/graphrag/prompt_tune/prompt/persona.py create mode 100644 func-app/graphrag/prompt_tune/template/__init__.py create mode 100644 func-app/graphrag/prompt_tune/template/community_report_summarization.py create mode 100644 func-app/graphrag/prompt_tune/template/entity_extraction.py create mode 100644 func-app/graphrag/prompt_tune/template/entity_summarization.py create mode 100644 func-app/graphrag/query/__init__.py create mode 100644 func-app/graphrag/query/__main__.py create mode 100644 func-app/graphrag/query/cli.py create mode 100644 func-app/graphrag/query/context_builder/__init__.py create mode 100644 func-app/graphrag/query/context_builder/builders.py create mode 100644 func-app/graphrag/query/context_builder/community_context.py create mode 100644 func-app/graphrag/query/context_builder/conversation_history.py create mode 100644 
func-app/graphrag/query/context_builder/entity_extraction.py create mode 100644 func-app/graphrag/query/context_builder/local_context.py create mode 100644 func-app/graphrag/query/context_builder/source_context.py create mode 100644 func-app/graphrag/query/factories.py create mode 100644 func-app/graphrag/query/indexer_adapters.py create mode 100644 func-app/graphrag/query/input/__init__.py create mode 100644 func-app/graphrag/query/input/loaders/__init__.py create mode 100644 func-app/graphrag/query/input/loaders/dfs.py create mode 100644 func-app/graphrag/query/input/loaders/utils.py create mode 100644 func-app/graphrag/query/input/retrieval/__init__.py create mode 100644 func-app/graphrag/query/input/retrieval/community_reports.py create mode 100644 func-app/graphrag/query/input/retrieval/covariates.py create mode 100644 func-app/graphrag/query/input/retrieval/entities.py create mode 100644 func-app/graphrag/query/input/retrieval/relationships.py create mode 100644 func-app/graphrag/query/input/retrieval/text_units.py create mode 100644 func-app/graphrag/query/llm/__init__.py create mode 100644 func-app/graphrag/query/llm/base.py create mode 100644 func-app/graphrag/query/llm/oai/__init__.py create mode 100644 func-app/graphrag/query/llm/oai/base.py create mode 100644 func-app/graphrag/query/llm/oai/chat_openai.py create mode 100644 func-app/graphrag/query/llm/oai/embedding.py create mode 100644 func-app/graphrag/query/llm/oai/openai.py create mode 100644 func-app/graphrag/query/llm/oai/typing.py create mode 100644 func-app/graphrag/query/llm/text_utils.py create mode 100644 func-app/graphrag/query/progress.py create mode 100644 func-app/graphrag/query/question_gen/__init__.py create mode 100644 func-app/graphrag/query/question_gen/base.py create mode 100644 func-app/graphrag/query/question_gen/local_gen.py create mode 100644 func-app/graphrag/query/question_gen/system_prompt.py create mode 100644 func-app/graphrag/query/structured_search/__init__.py create mode 100644 func-app/graphrag/query/structured_search/base.py create mode 100644 func-app/graphrag/query/structured_search/global_search/__init__.py create mode 100644 func-app/graphrag/query/structured_search/global_search/callbacks.py create mode 100644 func-app/graphrag/query/structured_search/global_search/community_context.py create mode 100644 func-app/graphrag/query/structured_search/global_search/map_system_prompt.py create mode 100644 func-app/graphrag/query/structured_search/global_search/reduce_system_prompt.py create mode 100644 func-app/graphrag/query/structured_search/global_search/search.py create mode 100644 func-app/graphrag/query/structured_search/local_search/__init__.py create mode 100644 func-app/graphrag/query/structured_search/local_search/mixed_context.py create mode 100644 func-app/graphrag/query/structured_search/local_search/search.py create mode 100644 func-app/graphrag/query/structured_search/local_search/system_prompt.py create mode 100644 func-app/graphrag/vector_stores/__init__.py create mode 100644 func-app/graphrag/vector_stores/azure_ai_search.py create mode 100644 func-app/graphrag/vector_stores/base.py create mode 100644 func-app/graphrag/vector_stores/kusto.py create mode 100644 func-app/graphrag/vector_stores/lancedb.py create mode 100644 func-app/graphrag/vector_stores/typing.py create mode 100644 func-app/host.json create mode 100644 func-app/prompts/claim_extraction.txt create mode 100644 func-app/prompts/community_report.txt create mode 100644 func-app/prompts/entity_extraction.txt create 
mode 100644 func-app/prompts/summarize_descriptions.txt create mode 100644 func-app/requirements.txt create mode 100644 func-app/settings/settings.yaml diff --git a/func-app/.gitignore b/func-app/.gitignore new file mode 100644 index 0000000000..c9d14718e1 --- /dev/null +++ b/func-app/.gitignore @@ -0,0 +1,49 @@ +bin +obj +csx +.vs +edge +Publish + +*.user +*.suo +*.cscfg +*.Cache +project.lock.json + +/packages +/TestResults + +/tools/NuGet.exe +/App_Data +/secrets +/data +.secrets +appsettings.json +local.settings.json + +node_modules +dist + +# Local python packages +.python_packages/ + +# Python Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Azurite artifacts +__blobstorage__ +__queuestorage__ +__azurite_db*__.json + diff --git a/func-app/.vscode/launch.json b/func-app/.vscode/launch.json new file mode 100644 index 0000000000..a90b7259e1 --- /dev/null +++ b/func-app/.vscode/launch.json @@ -0,0 +1,13 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Attach to Python Functions", + "type": "python", + "request": "attach", + "port": 7071, + "preLaunchTask": "func: host start", + "justMyCode": true + } + ] +} \ No newline at end of file diff --git a/func-app/.vscode/settings.json b/func-app/.vscode/settings.json new file mode 100644 index 0000000000..80db5bb047 --- /dev/null +++ b/func-app/.vscode/settings.json @@ -0,0 +1,8 @@ +{ + "azureFunctions.deploySubpath": ".", + "azureFunctions.scmDoBuildDuringDeployment": true, + "azureFunctions.pythonVenv": ".venv", + "azureFunctions.projectLanguage": "Python", + "azureFunctions.projectRuntime": "~4", + "debug.internalConsoleOptions": "neverOpen", +} \ No newline at end of file diff --git a/func-app/.vscode/tasks.json b/func-app/.vscode/tasks.json new file mode 100644 index 0000000000..808884468c --- /dev/null +++ b/func-app/.vscode/tasks.json @@ -0,0 +1,26 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "type": "func", + "command": "host start", + "problemMatcher": "$func-watch", + "isBackground": true, + "dependsOn": "pipInstall" + }, + { + "label": "pipInstall", + "type": "shell", + "osx": { + "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt" + }, + "windows": { + "command": "${config:azureFunctions.pythonVenv}\\Scripts\\python -m pip install -r requirements.txt" + }, + "linux": { + "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt" + }, + "problemMatcher": [] + } + ] +} \ No newline at end of file diff --git a/func-app/common/graph_db_client.py b/func-app/common/graph_db_client.py new file mode 100644 index 0000000000..ab4503f01c --- /dev/null +++ b/func-app/common/graph_db_client.py @@ -0,0 +1,158 @@ +import os +import pandas as pd + +from graphrag.config.models.graphdb_config import GraphDBConfig +import numpy as np + +import ast + +from gremlin_python.driver import client, serializer +from azure.identity import ManagedIdentityCredential + +import time +import os +import json + +# Azure Cosmos DB Gremlin Endpoint and other constants +COSMOS_DB_SCOPE = "https://cosmos.azure.com/.default" # The scope for Cosmos DB +class GraphDBClient: + def __init__(self,graph_db_params: GraphDBConfig|None,context_id: str|None): + self.username_prefix=graph_db_params.username + token = f"{graph_db_params.account_key}" + #if(os.environ.get("ENVIRONMENT") == "AZURE"): + # credential = 
ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3")
+        #    token = credential.get_token(COSMOS_DB_SCOPE)
+        self._client=client.Client(
+            url=f"{graph_db_params.gremlin_url}",
+            traversal_source="g",
+            username=self.username_prefix+"-contextid-"+context_id,
+            password=token,
+            message_serializer=serializer.GraphSONSerializersV2d0(),
+        )
+
+    def result_to_df(self,result) -> pd.DataFrame:
+        json_data = []
+        for row in result:
+            json_row = row[0]
+            properties_dict = json_row.pop('properties')
+            formatted_properties={}
+            for k,v in properties_dict.items():
+                new_val=v
+                if isinstance(v,list) and isinstance(v[0],dict):
+                    new_val=v[0]['value']
+                if k=='description_embedding' or k =='text_unit_ids' or k=='graph_embedding':
+                    new_val=ast.literal_eval(new_val)
+                if isinstance(new_val,list):
+                    new_val=np.array(new_val)
+                formatted_properties[k]=new_val
+            json_row.update(formatted_properties)
+            json_data.append(json_row)
+        df = pd.DataFrame(json_data)
+        return df
+
+    def remove_graph(self):
+        self._client.submit(message=("g.V().drop()"))
+
+    def query_vertices(self,context_id:str) -> pd.DataFrame:
+        result = self._client.submit(
+            message=(
+                "g.V()"
+            ),
+        )
+        return self.result_to_df(result)
+
+    def query_edges(self,context_id:str) -> pd.DataFrame:
+        result = self._client.submit(
+            message=(
+                "g.E()"
+            ),
+        )
+        return self.result_to_df(result)
+
+    def element_exists(self,element_type:str,element_id:int,conditions:str="")->bool:
+        result=self._client.submit(
+            message=(
+                element_type+
+                ".has('id',prop_id)"+
+                conditions+
+                ".count()"
+            ),
+            bindings={
+                "prop_id":element_id,
+            }
+        )
+        element_count=0
+        for counts in result:
+            element_count=counts[0]
+        return element_count>0
+
+    def write_vertices(self,data: pd.DataFrame)->None:
+        for row in data.itertuples():
+            if self.element_exists("g.V()",row.id):
+                continue
+            else:
+                self._client.submit(
+                    message=(
+                        "g.addV('entity')"
+                        ".property('id', prop_id)"
+                        ".property('name', prop_name)"
+                        ".property('type', prop_type)"
+                        ".property('description', prop_description)"
+                        ".property('human_readable_id', prop_human_readable_id)"
+                        ".property('category', prop_partition_key)"
+                        ".property(list,'description_embedding',prop_description_embedding)"
+                        ".property(list,'graph_embedding',prop_graph_embedding)"
+                        ".property(list,'text_unit_ids',prop_text_unit_ids)"
+                    ),
+                    bindings={
+                        "prop_id": row.id,
+                        "prop_name": row.name,
+                        "prop_type": row.type,
+                        "prop_description": row.description,
+                        "prop_human_readable_id": row.human_readable_id,
+                        "prop_partition_key": "entities",
+                        "prop_description_embedding":json.dumps(row.description_embedding.tolist() if row.description_embedding is not None else []),
+                        "prop_graph_embedding":json.dumps(row.graph_embedding.tolist() if row.graph_embedding is not None else []),
+                        "prop_text_unit_ids":json.dumps(row.text_unit_ids.tolist() if row.text_unit_ids is not None else []),
+                    },
+                )
+            time.sleep(5)
+
+
+    def write_edges(self,data: pd.DataFrame)->None:
+        for row in data.itertuples():
+            if self.element_exists("g.E()",row.id):
+                continue
+            self._client.submit(
+                message=(
+                    "g.V().has('name',prop_source_id)"
+                    ".addE('connects')"
+                    ".to(g.V().has('name',prop_target_id))"
+                    ".property('weight',prop_weight)"
+                    ".property(list,'text_unit_ids',prop_text_unit_ids)"
+                    ".property('description',prop_description)"
+                    ".property('id',prop_id)"
+                    ".property('human_readable_id',prop_human_readable_id)"
+                    ".property('source_degree',prop_source_degree)"
+                    ".property('target_degree',prop_target_degree)"
+
".property('rank',prop_rank)" + ".property('source',prop_source)" + ".property('target',prop_target)" + ), + bindings={ + "prop_partition_key": "entities", + "prop_source_id": row.source, + "prop_target_id": row.target, + "prop_weight": row.weight, + "prop_text_unit_ids":json.dumps(row.text_unit_ids.tolist() if row.text_unit_ids is not None else []), + "prop_description": row.description, + "prop_id": row.id, + "prop_human_readable_id": row.human_readable_id, + "prop_source_degree": row.source_degree, + "prop_target_degree": row.target_degree, + "prop_rank": row.rank, + "prop_source": row.source, + "prop_target": row.target, + }, + ) + time.sleep(5) \ No newline at end of file diff --git a/func-app/function_app.py b/func-app/function_app.py new file mode 100644 index 0000000000..38df93461f --- /dev/null +++ b/func-app/function_app.py @@ -0,0 +1,37 @@ +import azure.functions as func +import datetime +import json +import logging +import csv +import codecs +from graphrag.index.cli import index_cli + +app = func.FunctionApp() + +@app.function_name('IndexingPipelineFunc') +@app.route(route="index", auth_level=func.AuthLevel.ANONYMOUS) +def indexing(req: func.HttpRequest) -> func.HttpResponse: + logging.info('Python HTTP trigger function processed a request.') + + index_cli( + root = "", + verbose=False, + resume=False, + memprofile=False, + nocache=False, + config=None, + emit=None, + dryrun=False, + init=True, + overlay_defaults=False, + cli=True, + context_id=None, + context_operation=None, + community_level=None, + use_kusto_community_reports=None, + optimized_search=None + ) + return func.HttpResponse( + "Wow this first HTTP Function works!!!!", + status_code=200 + ) diff --git a/func-app/graphrag/__init__.py b/func-app/graphrag/__init__.py new file mode 100644 index 0000000000..a1e9b589bf --- /dev/null +++ b/func-app/graphrag/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The GraphRAG package.""" diff --git a/func-app/graphrag/common/blob_storage_client.py b/func-app/graphrag/common/blob_storage_client.py new file mode 100644 index 0000000000..78dc809579 --- /dev/null +++ b/func-app/graphrag/common/blob_storage_client.py @@ -0,0 +1,58 @@ +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient + + +class BlobStorageClient: + """The Blob-Storage implementation.""" + + _connection_string: str | None + _container_name: str + _path_prefix: str + _encoding: str + _storage_account_blob_url: str | None + + def __init__( + self, + connection_string: str | None, + container_name: str, + encoding: str | None = None, + path_prefix: str | None = None, + storage_account_blob_url: str | None = None, + ): + """Create a new BlobStorage instance.""" + if connection_string: + self._blob_service_client = BlobServiceClient.from_connection_string( + connection_string + ) + else: + if storage_account_blob_url is None: + msg = "Either connection_string or storage_account_blob_url must be provided." 
+ raise ValueError(msg) + + self._blob_service_client = BlobServiceClient( + account_url=storage_account_blob_url, + credential=DefaultAzureCredential(), + ) + self._encoding = encoding or "utf-8" + self._container_name = container_name + self._connection_string = connection_string + self._path_prefix = path_prefix or "" + self._storage_account_blob_url = storage_account_blob_url + self._storage_account_name = ( + storage_account_blob_url.split("//")[1].split(".")[0] + if storage_account_blob_url + else None + ) + #log.info( + # "creating blob storage at container=%s, path=%s", + # self._container_name, + # self._path_prefix, + #) + + def get_blob_service_client(self): + """Get the BlobServiceClient instance.""" + return self._blob_service_client + + def get_container_client(self): + """Get the container client instance.""" + return self._blob_service_client.get_container_client(self._container_name) diff --git a/func-app/graphrag/common/config/storage.py b/func-app/graphrag/common/config/storage.py new file mode 100644 index 0000000000..023d50e249 --- /dev/null +++ b/func-app/graphrag/common/config/storage.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'PipelineStorageConfig', 'PipelineFileStorageConfig' and 'PipelineMemoryStorageConfig' models.""" + +from __future__ import annotations + +from typing import Generic, Literal, TypeVar + +from pydantic import BaseModel +from pydantic import Field as pydantic_Field + +from graphrag.config.enums import StorageType + +T = TypeVar("T") + + +class PipelineStorageConfig(BaseModel, Generic[T]): + """Represent the storage configuration for the pipeline.""" + + type: T + + +class PipelineFileStorageConfig(PipelineStorageConfig[Literal[StorageType.file]]): + """Represent the file storage configuration for the pipeline.""" + + type: Literal[StorageType.file] = StorageType.file + """The type of storage.""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the storage.", default=None + ) + """The base directory for the storage.""" + + +class PipelineMemoryStorageConfig(PipelineStorageConfig[Literal[StorageType.memory]]): + """Represent the memory storage configuration for the pipeline.""" + + type: Literal[StorageType.memory] = StorageType.memory + """The type of storage.""" + + +class PipelineBlobStorageConfig(PipelineStorageConfig[Literal[StorageType.blob]]): + """Represents the blob storage configuration for the pipeline.""" + + type: Literal[StorageType.blob] = StorageType.blob + """The type of storage.""" + + connection_string: str | None = pydantic_Field( + description="The blob storage connection string for the storage.", default=None + ) + """The blob storage connection string for the storage.""" + + container_name: str = pydantic_Field( + description="The container name for storage", default=None + ) + """The container name for storage.""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the storage.", default=None + ) + """The base directory for the storage.""" + + storage_account_blob_url: str | None = pydantic_Field( + description="The storage account blob url.", default=None + ) + """The storage account blob url.""" + + +PipelineStorageConfigTypes = ( + PipelineFileStorageConfig | PipelineMemoryStorageConfig | PipelineBlobStorageConfig +) diff --git a/func-app/graphrag/common/graph_db_client.py b/func-app/graphrag/common/graph_db_client.py new file mode 100644 index 0000000000..35b70dd385 --- /dev/null 
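For illustration, the blob variant of these storage models can be constructed directly; the container and directory names below are invented:

    from graphrag.common.config.storage import PipelineBlobStorageConfig
    from graphrag.config.enums import StorageType

    cfg = PipelineBlobStorageConfig(
        container_name="graphrag-artifacts",
        base_dir="output",
        storage_account_blob_url="https://<account>.blob.core.windows.net",
    )
    # Each subclass pins its own `type` literal, so a union-typed config
    # can be discriminated with a simple comparison.
    assert cfg.type == StorageType.blob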
+++ b/func-app/graphrag/common/graph_db_client.py @@ -0,0 +1 @@ +# create Gremlin and cosmos db clients by reading settings from settings.yaml \ No newline at end of file diff --git a/func-app/graphrag/common/kusto_db_client.py b/func-app/graphrag/common/kusto_db_client.py new file mode 100644 index 0000000000..413a47341e --- /dev/null +++ b/func-app/graphrag/common/kusto_db_client.py @@ -0,0 +1 @@ +# create Gremlin and kusto db clients by reading settings from settings.yaml \ No newline at end of file diff --git a/func-app/graphrag/common/progress/__init__.py b/func-app/graphrag/common/progress/__init__.py new file mode 100644 index 0000000000..df6a21523d --- /dev/null +++ b/func-app/graphrag/common/progress/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Progress-reporting components.""" + +from .types import NullProgressReporter, PrintProgressReporter, ProgressReporter + +__all__ = ["NullProgressReporter", "PrintProgressReporter", "ProgressReporter"] diff --git a/func-app/graphrag/common/progress/rich.py b/func-app/graphrag/common/progress/rich.py new file mode 100644 index 0000000000..362b64f0c8 --- /dev/null +++ b/func-app/graphrag/common/progress/rich.py @@ -0,0 +1,165 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Rich-based progress reporter for CLI use.""" + +# Print iterations progress +import asyncio + +from datashaper import Progress as DSProgress +from rich.console import Console, Group +from rich.live import Live +from rich.progress import Progress, TaskID, TimeElapsedColumn +from rich.spinner import Spinner +from rich.tree import Tree + +from .types import ProgressReporter + + +# https://stackoverflow.com/a/34325723 +class RichProgressReporter(ProgressReporter): + """A rich-based progress reporter for CLI use.""" + + _console: Console + _group: Group + _tree: Tree + _live: Live + _task: TaskID | None = None + _prefix: str + _transient: bool + _disposing: bool = False + _progressbar: Progress + _last_refresh: float = 0 + + def dispose(self) -> None: + """Dispose of the progress reporter.""" + self._disposing = True + self._live.stop() + + @property + def console(self) -> Console: + """Get the console.""" + return self._console + + @property + def group(self) -> Group: + """Get the group.""" + return self._group + + @property + def tree(self) -> Tree: + """Get the tree.""" + return self._tree + + @property + def live(self) -> Live: + """Get the live.""" + return self._live + + def __init__( + self, + prefix: str, + parent: "RichProgressReporter | None" = None, + transient: bool = True, + ) -> None: + """Create a new rich-based progress reporter.""" + self._prefix = prefix + + if parent is None: + console = Console() + group = Group(Spinner("dots", prefix), fit=True) + tree = Tree(group) + live = Live( + tree, console=console, refresh_per_second=1, vertical_overflow="crop" + ) + live.start() + + self._console = console + self._group = group + self._tree = tree + self._live = live + self._transient = False + else: + self._console = parent.console + self._group = parent.group + progress_columns = [*Progress.get_default_columns(), TimeElapsedColumn()] + self._progressbar = Progress( + *progress_columns, console=self._console, transient=transient + ) + + tree = Tree(prefix) + tree.add(self._progressbar) + tree.hide_root = True + + if parent is not None: + parent_tree = parent.tree + parent_tree.hide_root = False + parent_tree.add(tree) + + self._tree = tree + self._live = 
parent.live + self._transient = transient + + self.refresh() + + def refresh(self) -> None: + """Perform a debounced refresh.""" + now = asyncio.get_event_loop().time() + duration = now - self._last_refresh + if duration > 0.1: + self._last_refresh = now + self.force_refresh() + + def force_refresh(self) -> None: + """Force a refresh.""" + self.live.refresh() + + def stop(self) -> None: + """Stop the progress reporter.""" + self._live.stop() + + def child(self, prefix: str, transient: bool = True) -> ProgressReporter: + """Create a child progress bar.""" + return RichProgressReporter(parent=self, prefix=prefix, transient=transient) + + def error(self, message: str) -> None: + """Report an error.""" + self._console.print(f"❌ [red]{message}[/red]") + + def warning(self, message: str) -> None: + """Report a warning.""" + self._console.print(f"⚠️ [yellow]{message}[/yellow]") + + def success(self, message: str) -> None: + """Report success.""" + self._console.print(f"🚀 [green]{message}[/green]") + + def info(self, message: str) -> None: + """Report information.""" + self._console.print(message) + + def __call__(self, progress_update: DSProgress) -> None: + """Update progress.""" + if self._disposing: + return + progressbar = self._progressbar + + if self._task is None: + self._task = progressbar.add_task(self._prefix) + + progress_description = "" + if progress_update.description is not None: + progress_description = f" - {progress_update.description}" + + completed = progress_update.completed_items or progress_update.percent + total = progress_update.total_items or 1 + progressbar.update( + self._task, + completed=completed, + total=total, + description=f"{self._prefix}{progress_description}", + ) + if completed == total and self._transient: + progressbar.update(self._task, visible=False) + + self.refresh() diff --git a/func-app/graphrag/common/progress/types.py b/func-app/graphrag/common/progress/types.py new file mode 100644 index 0000000000..2912155ed1 --- /dev/null +++ b/func-app/graphrag/common/progress/types.py @@ -0,0 +1,128 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Types for status reporting.""" + +from abc import ABC, abstractmethod + +from datashaper import Progress + + +class ProgressReporter(ABC): + """ + Abstract base class for progress reporters. + + This is used to report workflow processing progress via mechanisms like progress-bars. 
+ """ + + @abstractmethod + def __call__(self, update: Progress): + """Update progress.""" + + @abstractmethod + def dispose(self): + """Dispose of the progress reporter.""" + + @abstractmethod + def child(self, prefix: str, transient=True) -> "ProgressReporter": + """Create a child progress bar.""" + + @abstractmethod + def force_refresh(self) -> None: + """Force a refresh.""" + + @abstractmethod + def stop(self) -> None: + """Stop the progress reporter.""" + + @abstractmethod + def error(self, message: str) -> None: + """Report an error.""" + + @abstractmethod + def warning(self, message: str) -> None: + """Report a warning.""" + + @abstractmethod + def info(self, message: str) -> None: + """Report information.""" + + @abstractmethod + def success(self, message: str) -> None: + """Report success.""" + + +class NullProgressReporter(ProgressReporter): + """A progress reporter that does nothing.""" + + def __call__(self, update: Progress) -> None: + """Update progress.""" + + def dispose(self) -> None: + """Dispose of the progress reporter.""" + + def child(self, prefix: str, transient: bool = True) -> ProgressReporter: + """Create a child progress bar.""" + return self + + def force_refresh(self) -> None: + """Force a refresh.""" + + def stop(self) -> None: + """Stop the progress reporter.""" + + def error(self, message: str) -> None: + """Report an error.""" + + def warning(self, message: str) -> None: + """Report a warning.""" + + def info(self, message: str) -> None: + """Report information.""" + + def success(self, message: str) -> None: + """Report success.""" + + +class PrintProgressReporter(ProgressReporter): + """A progress reporter that does nothing.""" + + prefix: str + + def __init__(self, prefix: str): + """Create a new progress reporter.""" + self.prefix = prefix + print(f"\n{self.prefix}", end="") # noqa T201 + + def __call__(self, update: Progress) -> None: + """Update progress.""" + print(".", end="") # noqa T201 + + def dispose(self) -> None: + """Dispose of the progress reporter.""" + + def child(self, prefix: str, transient: bool = True) -> "ProgressReporter": + """Create a child progress bar.""" + return PrintProgressReporter(prefix) + + def stop(self) -> None: + """Stop the progress reporter.""" + + def force_refresh(self) -> None: + """Force a refresh.""" + + def error(self, message: str) -> None: + """Report an error.""" + print(f"\n{self.prefix}ERROR: {message}") # noqa T201 + + def warning(self, message: str) -> None: + """Report a warning.""" + print(f"\n{self.prefix}WARNING: {message}") # noqa T201 + + def info(self, message: str) -> None: + """Report information.""" + print(f"\n{self.prefix}INFO: {message}") # noqa T201 + + def success(self, message: str) -> None: + """Report success.""" + print(f"\n{self.prefix}SUCCESS: {message}") # noqa T201 diff --git a/func-app/graphrag/common/storage/__init__.py b/func-app/graphrag/common/storage/__init__.py new file mode 100644 index 0000000000..7ca943db52 --- /dev/null +++ b/func-app/graphrag/common/storage/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""The Indexing Engine storage package root.""" + +from .blob_pipeline_storage import BlobPipelineStorage, create_blob_storage +from .file_pipeline_storage import FilePipelineStorage +from .load_storage import load_storage +from .memory_pipeline_storage import MemoryPipelineStorage +from .typing import PipelineStorage + +__all__ = [ + "BlobPipelineStorage", + "FilePipelineStorage", + "MemoryPipelineStorage", + "PipelineStorage", + "create_blob_storage", + "load_storage", +] diff --git a/func-app/graphrag/common/storage/blob_pipeline_storage.py b/func-app/graphrag/common/storage/blob_pipeline_storage.py new file mode 100644 index 0000000000..568ec89bcc --- /dev/null +++ b/func-app/graphrag/common/storage/blob_pipeline_storage.py @@ -0,0 +1,375 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Azure Blob Storage implementation of PipelineStorage.""" + +import logging +import re +from collections.abc import Iterator +from pathlib import Path +from typing import Any + +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient +from datashaper import Progress + +from graphrag.common.progress import ProgressReporter + +from .typing import PipelineStorage + +log = logging.getLogger(__name__) + + +class BlobPipelineStorage(PipelineStorage): + """The Blob-Storage implementation.""" + + _connection_string: str | None + _container_name: str + _path_prefix: str + _encoding: str + _storage_account_blob_url: str | None + + def __init__( + self, + connection_string: str | None, + container_name: str, + encoding: str | None = None, + path_prefix: str | None = None, + storage_account_blob_url: str | None = None, + overwrite: bool = False + ): + """Create a new BlobStorage instance.""" + if connection_string: + self._blob_service_client = BlobServiceClient.from_connection_string( + connection_string + ) + else: + if storage_account_blob_url is None: + msg = "Either connection_string or storage_account_blob_url must be provided." 
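For reference, one way the BlobPipelineStorage constructor above can be exercised; the connection string and names are placeholders:

    from graphrag.common.storage import BlobPipelineStorage

    storage = BlobPipelineStorage(
        connection_string="DefaultEndpointsProtocol=https;AccountName=<name>;AccountKey=<key>",
        container_name="pipeline-cache",
        path_prefix="run-42",
        overwrite=True,  # let set() replace blobs that already exist
    )
    # __init__ ensures the container exists before first use.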
+ raise ValueError(msg) + + self._blob_service_client = BlobServiceClient( + account_url=storage_account_blob_url, + credential=DefaultAzureCredential(managed_identity_client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), + ) + self._encoding = encoding or "utf-8" + self._container_name = container_name + self._connection_string = connection_string + self._overwrite = overwrite + self._path_prefix = path_prefix or "" + self._storage_account_blob_url = storage_account_blob_url + self._storage_account_name = ( + storage_account_blob_url.split("//")[1].split(".")[0] + if storage_account_blob_url + else None + ) + log.info( + "creating blob storage at container=%s, path=%s", + self._container_name, + self._path_prefix, + ) + self.create_container() + + def create_container(self) -> None: + """Create the container if it does not exist.""" + if not self.container_exists(): + container_name = self._container_name + container_names = [ + container.name + for container in self._blob_service_client.list_containers() + ] + if container_name not in container_names: + self._blob_service_client.create_container(container_name) + + def delete_container(self) -> None: + """Delete the container.""" + if self.container_exists(): + self._blob_service_client.delete_container(self._container_name) + + def container_exists(self) -> bool: + """Check if the container exists.""" + container_name = self._container_name + container_names = [ + container.name for container in self._blob_service_client.list_containers() + ] + return container_name in container_names + + def find( + self, + file_pattern: re.Pattern[str], + base_dir: str | None = None, + progress: ProgressReporter | None = None, + file_filter: dict[str, Any] | None = None, + max_count=-1, + ) -> Iterator[tuple[str, dict[str, Any]]]: + """Find blobs in a container using a file pattern, as well as a custom filter function. + + Params: + base_dir: The name of the base container. + file_pattern: The file pattern to use. + file_filter: A dictionary of key-value pairs to filter the blobs. + max_count: The maximum number of blobs to return. If -1, all blobs are returned. + + Returns + ------- + An iterator of blob names and their corresponding regex matches. 
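To make the find() contract concrete, a hypothetical call that pairs a named-group pattern with a file_filter, reusing the storage instance from the previous sketch:

    import re

    results = storage.find(
        re.compile(r"input/(?P<source>[^/]+)\.csv$"),
        base_dir="input",
        file_filter={"source": r"reports.*"},  # regex matched against the named group
        max_count=10,
    )
    for blob_name, groups in results:
        print(blob_name, groups["source"])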
+        """
+        base_dir = base_dir or ""
+
+        log.info(
+            "search container %s for files matching %s",
+            self._container_name,
+            file_pattern.pattern,
+        )
+
+        def blobname(blob_name: str) -> str:
+            if blob_name.startswith(self._path_prefix):
+                blob_name = blob_name.replace(self._path_prefix, "", 1)
+            if blob_name.startswith("/"):
+                blob_name = blob_name[1:]
+            return blob_name
+
+        def item_filter(item: dict[str, Any]) -> bool:
+            if file_filter is None:
+                return True
+
+            return all(re.match(value, item[key]) for key, value in file_filter.items())
+
+        try:
+            container_client = self._blob_service_client.get_container_client(
+                self._container_name
+            )
+            all_blobs = list(container_client.list_blobs())
+
+            num_loaded = 0
+            num_total = len(all_blobs)
+            num_filtered = 0
+            for blob in all_blobs:
+                match = file_pattern.match(blob.name)
+                if match and blob.name.startswith(base_dir):
+                    group = match.groupdict()
+                    if item_filter(group):
+                        yield (blobname(blob.name), group)
+                        num_loaded += 1
+                        if max_count > 0 and num_loaded >= max_count:
+                            break
+                    else:
+                        num_filtered += 1
+                else:
+                    num_filtered += 1
+                if progress is not None:
+                    progress(
+                        _create_progress_status(num_loaded, num_filtered, num_total)
+                    )
+        except Exception:
+            log.exception(
+                "Error finding blobs: base_dir=%s, file_pattern=%s, file_filter=%s",
+                base_dir,
+                file_pattern,
+                file_filter,
+            )
+            raise
+
+    async def get(
+        self, key: str, as_bytes: bool | None = False, encoding: str | None = None
+    ) -> Any:
+        """Get a value from the cache."""
+        try:
+            key = self._keyname(key)
+            container_client = self._blob_service_client.get_container_client(
+                self._container_name
+            )
+            blob_client = container_client.get_blob_client(key)
+            blob_data = blob_client.download_blob().readall()
+            if not as_bytes:
+                coding = encoding or "utf-8"
+                blob_data = blob_data.decode(coding)
+        except Exception:
+            log.exception("Error getting key %s", key)
+            return None
+        else:
+            return blob_data
+
+    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
+        """Set a value in the cache."""
+        try:
+            key = self._keyname(key)
+            container_client = self._blob_service_client.get_container_client(
+                self._container_name
+            )
+            blob_client = container_client.get_blob_client(key)
+            if blob_client.exists() and not self._overwrite:
+                msg = "Artifact already exists; make sure the output folder is empty."
+                raise ValueError(msg)
+            if isinstance(value, bytes):
+                blob_client.upload_blob(value, overwrite=True)
+            else:
+                coding = encoding or "utf-8"
+                blob_client.upload_blob(value.encode(coding), overwrite=True)
+        except Exception:
+            log.exception("Error setting key %s", key)
+
+    def set_df_json(self, key: str, dataframe: Any) -> None:
+        """Write a dataframe to storage as JSON."""
+        # pandas resolves abfs:// URLs through fsspec/adlfs using the given
+        # storage options.
+        if self._connection_string is None and self._storage_account_name:
+            dataframe.to_json(
+                self._abfs_url(key),
+                storage_options={
+                    "account_name": self._storage_account_name,
+                    "credential": DefaultAzureCredential(),
+                },
+                orient="records",
+                lines=True,
+                force_ascii=False,
+            )
+        else:
+            dataframe.to_json(
+                self._abfs_url(key),
+                storage_options={"connection_string": self._connection_string},
+                orient="records",
+                lines=True,
+                force_ascii=False,
+            )
+
+    def set_df_parquet(self, key: str, dataframe: Any) -> None:
+        """Write a dataframe to storage as Parquet."""
+        if self._connection_string is None and self._storage_account_name:
+            dataframe.to_parquet(
+                self._abfs_url(key),
+                storage_options={
+                    "account_name": self._storage_account_name,
+                    "credential": DefaultAzureCredential(),
+                },
+            )
+        else:
+            dataframe.to_parquet(
+                self._abfs_url(key),
+                storage_options={"connection_string": self._connection_string},
+            )
+
+    async def has(self, key: str) -> bool:
+        """Check if a key exists in the cache."""
+        key = self._keyname(key)
+        container_client = self._blob_service_client.get_container_client(
+            self._container_name
+        )
+        blob_client = container_client.get_blob_client(key)
+        return blob_client.exists()
+
+    async def delete(self, key: str) -> None:
+        """Delete a key from the cache."""
+        key = self._keyname(key)
+        container_client = self._blob_service_client.get_container_client(
+            self._container_name
+        )
+        blob_client = container_client.get_blob_client(key)
+        blob_client.delete_blob()
+
+    async def clear(self) -> None:
+        """Clear the cache."""
+
+    def child(self, name: str | None) -> "PipelineStorage":
+        """Create a child storage instance."""
+        if name is None:
+            return self
+        path = str(Path(self._path_prefix) / name)
+        return BlobPipelineStorage(
+            self._connection_string,
+            self._container_name,
+            self._encoding,
+            path,
+            self._storage_account_blob_url,
+        )
+
+    def _keyname(self, key: str) -> str:
+        """Get the key name."""
+        return str(Path(self._path_prefix) / key)
+
+    def _abfs_url(self, key: str) -> str:
+        """Get the ABFS URL."""
+        path = str(Path(self._container_name) / self._path_prefix / key)
+        return f"abfs://{path}"
+
+
+def create_blob_storage(
+    connection_string: str | None,
+    storage_account_blob_url: str | None,
+    container_name: str,
+    base_dir: str | None,
+) -> PipelineStorage:
+    """Create a blob based storage."""
+    log.info("Creating blob storage at %s", container_name)
+    if container_name is None:
+        msg = "No container name provided for blob storage."
+        raise ValueError(msg)
+    if connection_string is None and storage_account_blob_url is None:
+        msg = "No connection string or storage account blob url provided for blob storage."
+        raise ValueError(msg)
+    return BlobPipelineStorage(
+        connection_string,
+        container_name,
+        path_prefix=base_dir,
+        storage_account_blob_url=storage_account_blob_url,
+    )
+
+
+def validate_blob_container_name(container_name: str) -> bool:
+    """
+    Check if the provided blob container name is valid based on Azure rules.
+
+        - A blob container name must be between 3 and 63 characters in length.
+        - Must start with a letter or number.
+        - All letters used in blob container names must be lowercase.
+        - Must contain only letters, numbers, or hyphens.
+        - Consecutive hyphens are not permitted.
+        - Cannot end with a hyphen.
+
+    Args:
+    -----
+        container_name (str)
+            The blob container name to be validated.
+
+    Returns
+    -------
+        bool: True if the name is valid.
+
+    Raises
+    ------
+        ValueError: If the name violates any of the rules above.
+    """
+    # Check the length of the name
+    if len(container_name) < 3 or len(container_name) > 63:
+        raise ValueError(
+            f"Container name must be between 3 and 63 characters in length. Name provided was {len(container_name)} characters long."
+        )
+
+    # Check if the name starts with a letter or number
+    if not container_name[0].isalnum():
+        raise ValueError(
+            f"Container name must start with a letter or number. Starting character was {container_name[0]}."
+        )
+
+    # Check for valid characters (letters, numbers, hyphen) and lowercase letters
+    if not re.match("^[a-z0-9-]+$", container_name):
+        raise ValueError(
+            f"Container name must only contain:\n- lowercase letters\n- numbers\n- or hyphens\nName provided was {container_name}."
+        )
+
+    # Check for consecutive hyphens
+    if "--" in container_name:
+        raise ValueError(
+            f"Container name cannot contain consecutive hyphens. Name provided was {container_name}."
+        )
+
+    # Check for hyphens at the end of the name
+    if container_name[-1] == "-":
+        raise ValueError(
+            f"Container name cannot end with a hyphen. Name provided was {container_name}."
+        )
+
+    return True
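+
+
+# Example usage: validate_blob_container_name("my-container-01") returns True,
+# while an invalid name such as "My_Container" raises a ValueError that names
+# the violated rule.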
+
+
+def _create_progress_status(
+    num_loaded: int, num_filtered: int, num_total: int
+) -> Progress:
+    return Progress(
+        total_items=num_total,
+        completed_items=num_loaded + num_filtered,
+        description=f"{num_loaded} files loaded ({num_filtered} filtered)",
+    )
diff --git a/func-app/graphrag/common/storage/file_pipeline_storage.py b/func-app/graphrag/common/storage/file_pipeline_storage.py
new file mode 100644
index 0000000000..212783e41f
--- /dev/null
+++ b/func-app/graphrag/common/storage/file_pipeline_storage.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing 'FileStorage' and 'FilePipelineStorage' models."""
+
+import logging
+import os
+import re
+import shutil
+from collections.abc import Iterator
+from pathlib import Path
+from typing import Any, cast
+
+import aiofiles
+from aiofiles.os import remove
+from aiofiles.ospath import exists
+from datashaper import Progress
+
+from graphrag.common.progress import ProgressReporter
+
+from .typing import PipelineStorage
+
+log = logging.getLogger(__name__)
+
+
+class FilePipelineStorage(PipelineStorage):
+    """File storage class definition."""
+
+    _root_dir: str
+    _encoding: str
+
+    def __init__(self, root_dir: str | None = None, encoding: str | None = None):
+        """Init method definition."""
+        self._root_dir = root_dir or ""
+        self._encoding = encoding or "utf-8"
+        Path(self._root_dir).mkdir(parents=True, exist_ok=True)
+
+    def find(
+        self,
+        file_pattern: re.Pattern[str],
+        base_dir: str | None = None,
+        progress: ProgressReporter | None = None,
+        file_filter: dict[str, Any] | None = None,
+        max_count=-1,
+    ) -> Iterator[tuple[str, dict[str, Any]]]:
+        """Find files in the storage using a file pattern, as well as a custom filter function."""
+
+        def item_filter(item: dict[str, Any]) -> bool:
+            if file_filter is None:
+                return True
+
+            return all(re.match(value, item[key]) for key, value in file_filter.items())
+
+        search_path = Path(self._root_dir) / (base_dir or "")
+        log.info("search %s for files matching %s", search_path, file_pattern.pattern)
+        all_files = list(search_path.rglob("**/*"))
+        num_loaded = 0
+        num_total = len(all_files)
+        num_filtered = 0
+        for file in all_files:
+            match = file_pattern.match(f"{file}")
+            if match:
+                group = match.groupdict()
+                if item_filter(group):
+                    filename = f"{file}".replace(self._root_dir, "")
+                    if filename.startswith(os.sep):
+                        filename = filename[1:]
+                    yield (filename, group)
+                    num_loaded += 1
+                    if max_count > 0 and num_loaded >= max_count:
+                        break
+                else:
+                    num_filtered += 1
+            else:
+                num_filtered += 1
+            if progress is not None:
+                progress(_create_progress_status(num_loaded, num_filtered, num_total))
+
+    async def get(
+        self, key: str, as_bytes: bool | None = False, encoding: str | None = None
+    ) -> Any:
+        """Get method definition."""
+        file_path = join_path(self._root_dir, key)
+
+        if await self.has(key):
+            return await self._read_file(file_path, as_bytes, encoding)
+        if await exists(key):
+            # Look up the key directly, as it is presumably a new file loaded
+            # from inputs and not yet written to storage
+            return await self._read_file(key, as_bytes, encoding)
+
+        return None
+
+    async def _read_file(
+        self,
+        path: str | Path,
+        as_bytes: bool | None = False,
+        encoding: str | None = None,
+    ) -> Any:
+        """Read the contents of a file."""
+        read_type = "rb" if as_bytes else "r"
+        encoding = None if as_bytes else (encoding or self._encoding)
+
+        async with aiofiles.open(
+            path,
+            cast(Any, read_type),
+            encoding=encoding,
+        ) as f:
+            return await f.read()
+
+    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
+        """Set method definition."""
+        is_bytes = isinstance(value, bytes)
+        write_type = "wb" if is_bytes else "w"
+        encoding = None if is_bytes else encoding or self._encoding
+        os.makedirs(os.path.dirname(join_path(self._root_dir, key)), exist_ok=True)
+        async with aiofiles.open(
+            join_path(self._root_dir, key),
+            cast(Any, write_type),
+            encoding=encoding,
+        ) as f:
+            await f.write(value)
+
+    async def has(self, key: str) -> bool:
+        """Has method definition."""
+        return await exists(join_path(self._root_dir, key))
+
+    async def delete(self, key: str) -> None:
+        """Delete method definition."""
+        if await self.has(key):
+            await remove(join_path(self._root_dir, key))
+
+    async def clear(self) -> None:
+        """Clear method definition."""
+        for file in Path(self._root_dir).glob("*"):
+            if file.is_dir():
+                shutil.rmtree(file)
+            else:
+                file.unlink()
+
+    def child(self, name: str | None) -> "PipelineStorage":
+        """Create a child storage instance."""
+        if name is None:
+            return self
+        return FilePipelineStorage(
+            str(Path(self._root_dir) / Path(name)), self._encoding
+        )
+
+
+def join_path(file_path: str, file_name: str) -> Path:
+    """Join a path and a file. Independent of the OS."""
+    return Path(file_path) / Path(file_name).parent / Path(file_name).name
+
+
+def create_file_storage(out_dir: str | None) -> PipelineStorage:
+    """Create a file based storage."""
+    log.info("Creating file storage at %s", out_dir)
+    return FilePipelineStorage(out_dir)
+
+
+def _create_progress_status(
+    num_loaded: int, num_filtered: int, num_total: int
+) -> Progress:
+    return Progress(
+        total_items=num_total,
+        completed_items=num_loaded + num_filtered,
+        description=f"{num_loaded} files loaded ({num_filtered} filtered)",
+    )
diff --git a/func-app/graphrag/common/storage/load_storage.py b/func-app/graphrag/common/storage/load_storage.py
new file mode 100644
index 0000000000..24a6675a04
--- /dev/null
+++ b/func-app/graphrag/common/storage/load_storage.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing load_storage method definition."""
+
+from __future__ import annotations
+
+from typing import cast
+
+from graphrag.config import StorageType
+from graphrag.common.config.storage import (
+    PipelineBlobStorageConfig,
+    PipelineFileStorageConfig,
+    PipelineStorageConfig,
+)
+
+from .blob_pipeline_storage import create_blob_storage
+from .file_pipeline_storage import create_file_storage
+from .memory_pipeline_storage import create_memory_storage
+
+
+def load_storage(config: PipelineStorageConfig):
+    """Load the storage for a pipeline."""
+    match config.type:
+        case StorageType.memory:
+            return create_memory_storage()
+        case StorageType.blob:
+            config = cast(PipelineBlobStorageConfig, config)
+            return create_blob_storage(
+                config.connection_string,
+                config.storage_account_blob_url,
+                config.container_name,
+                config.base_dir,
+            )
+        case StorageType.file:
+            config = cast(PipelineFileStorageConfig, config)
+            return create_file_storage(config.base_dir)
+        case _:
+            msg = f"Unknown storage type: {config.type}"
+            raise ValueError(msg)
diff --git a/func-app/graphrag/common/storage/memory_pipeline_storage.py b/func-app/graphrag/common/storage/memory_pipeline_storage.py
new file mode 100644
index 0000000000..2d1382e0af
--- /dev/null
+++ b/func-app/graphrag/common/storage/memory_pipeline_storage.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the 'MemoryPipelineStorage' model."""
+
+from typing import Any
+
+from .file_pipeline_storage import FilePipelineStorage
+from .typing import PipelineStorage
+
+
+class MemoryPipelineStorage(FilePipelineStorage):
+    """In memory storage class definition."""
+
+    _storage: dict[str, Any]
+
+    def __init__(self):
+        """Init method definition."""
+        super().__init__(root_dir=".output")
+        self._storage = {}
+
+    async def get(
+        self, key: str, as_bytes: bool | None = None, encoding: str | None = None
+    ) -> Any:
+        """Get the value for the given key.
+
+        Args:
+            - key - The key to get the value for.
+            - as_bytes - Whether or not to return the value as bytes.
+
+        Returns
+        -------
+            - output - The value for the given key.
+        """
+        return self._storage.get(key) or await super().get(key, as_bytes, encoding)
+
+    async def set(
+        self, key: str, value: str | bytes | None, encoding: str | None = None
+    ) -> None:
+        """Set the value for the given key.
+
+        Args:
+            - key - The key to set the value for.
+            - value - The value to set.
+        """
+        self._storage[key] = value
+
+    async def has(self, key: str) -> bool:
+        """Return True if the given key exists in the storage.
+
+        Args:
+            - key - The key to check for.
+
+        Returns
+        -------
+            - output - True if the key exists in the storage, False otherwise.
+        """
+        return key in self._storage or await super().has(key)
+
+    async def delete(self, key: str) -> None:
+        """Delete the given key from the storage.
+
+        Args:
+            - key - The key to delete.
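+
+        Note: only the in-memory entry is removed; a KeyError is raised if
+        the key was never set.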
+ """ + del self._storage[key] + + async def clear(self) -> None: + """Clear the storage.""" + self._storage.clear() + + def child(self, name: str | None) -> "PipelineStorage": + """Create a child storage instance.""" + return self + + +def create_memory_storage() -> PipelineStorage: + """Create memory storage.""" + return MemoryPipelineStorage() diff --git a/func-app/graphrag/common/storage/typing.py b/func-app/graphrag/common/storage/typing.py new file mode 100644 index 0000000000..c5f0de3265 --- /dev/null +++ b/func-app/graphrag/common/storage/typing.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'PipelineStorage' model.""" + +import re +from abc import ABCMeta, abstractmethod +from collections.abc import Iterator +from typing import Any + +from graphrag.common.progress import ProgressReporter + + +class PipelineStorage(metaclass=ABCMeta): + """Provide a storage interface for the pipeline. This is where the pipeline will store its output data.""" + + @abstractmethod + def find( + self, + file_pattern: re.Pattern[str], + base_dir: str | None = None, + progress: ProgressReporter | None = None, + file_filter: dict[str, Any] | None = None, + max_count=-1, + ) -> Iterator[tuple[str, dict[str, Any]]]: + """Find files in the storage using a file pattern, as well as a custom filter function.""" + + @abstractmethod + async def get( + self, key: str, as_bytes: bool | None = None, encoding: str | None = None + ) -> Any: + """Get the value for the given key. + + Args: + - key - The key to get the value for. + - as_bytes - Whether or not to return the value as bytes. + + Returns + ------- + - output - The value for the given key. + """ + + @abstractmethod + async def set( + self, key: str, value: str | bytes | None, encoding: str | None = None + ) -> None: + """Set the value for the given key. + + Args: + - key - The key to set the value for. + - value - The value to set. + """ + + @abstractmethod + async def has(self, key: str) -> bool: + """Return True if the given key exists in the storage. + + Args: + - key - The key to check for. + + Returns + ------- + - output - True if the key exists in the storage, False otherwise. + """ + + @abstractmethod + async def delete(self, key: str) -> None: + """Delete the given key from the storage. + + Args: + - key - The key to delete. 
+        """
+
+    @abstractmethod
+    async def clear(self) -> None:
+        """Clear the storage."""
+
+    @abstractmethod
+    def child(self, name: str | None) -> "PipelineStorage":
+        """Create a child storage instance."""
diff --git a/func-app/graphrag/common/utils/common_utils.py b/func-app/graphrag/common/utils/common_utils.py
new file mode 100644
index 0000000000..d53345d245
--- /dev/null
+++ b/func-app/graphrag/common/utils/common_utils.py
@@ -0,0 +1,11 @@
+import uuid
+
+def is_valid_guid(guid_str: str) -> bool:
+    """Check whether a string is a valid version-4 GUID."""
+    try:
+        # Attempt to create a UUID object
+        uuid_obj = uuid.UUID(guid_str, version=4)
+        # Check if the string representation matches the UUID object
+        return str(uuid_obj) == guid_str
+    except ValueError:
+        return False
\ No newline at end of file
diff --git a/func-app/graphrag/common/utils/context_utils.py b/func-app/graphrag/common/utils/context_utils.py
new file mode 100644
index 0000000000..687779e9c1
--- /dev/null
+++ b/func-app/graphrag/common/utils/context_utils.py
@@ -0,0 +1,9 @@
+from graphrag.config import (
+    GraphRagConfig,
+)
+
+def get_files_by_contextid(config: GraphRagConfig, context_id: str):
+    """Get the files associated with the given context id."""
+    # NOTE: eventually this will come from Cosmos DB or another storage backend.
+    files_in_context = config.query_context.files
+    return files_in_context
\ No newline at end of file
diff --git a/func-app/graphrag/config/__init__.py b/func-app/graphrag/config/__init__.py
new file mode 100644
index 0000000000..5870c4ae71
--- /dev/null
+++ b/func-app/graphrag/config/__init__.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine default config package root."""
+
+from .create_graphrag_config import (
+    create_graphrag_config,
+)
+from .enums import (
+    CacheType,
+    ContextSwitchType,
+    InputFileType,
+    InputType,
+    LLMType,
+    ReportingType,
+    StorageType,
+    TextEmbeddingTarget,
+)
+from .errors import (
+    ApiKeyMissingError,
+    AzureApiBaseMissingError,
+    AzureDeploymentNameMissingError,
+)
+from .input_models import (
+    CacheConfigInput,
+    ChunkingConfigInput,
+    ClaimExtractionConfigInput,
+    ClusterGraphConfigInput,
+    CommunityReportsConfigInput,
+    EmbedGraphConfigInput,
+    EntityExtractionConfigInput,
+    GlobalSearchConfigInput,
+    GraphRagConfigInput,
+    InputConfigInput,
+    LLMConfigInput,
+    LLMParametersInput,
+    LocalSearchConfigInput,
+    ParallelizationParametersInput,
+    ReportingConfigInput,
+    SnapshotsConfigInput,
+    StorageConfigInput,
+    SummarizeDescriptionsConfigInput,
+    TextEmbeddingConfigInput,
+    UmapConfigInput,
+)
+from .models import (
+    CacheConfig,
+    ChunkingConfig,
+    ClaimExtractionConfig,
+    ClusterGraphConfig,
+    CommunityReportsConfig,
+    EmbedGraphConfig,
+    EntityExtractionConfig,
+    GlobalSearchConfig,
+    GraphRagConfig,
+    InputConfig,
+    LLMConfig,
+    LLMParameters,
+    LocalSearchConfig,
+    ParallelizationParameters,
+    QueryContextConfig,
+    ReportingConfig,
+    SnapshotsConfig,
+    StorageConfig,
+    SummarizeDescriptionsConfig,
+    TextEmbeddingConfig,
+    UmapConfig,
+)
+from .read_dotenv import read_dotenv
+
+__all__ = [
+    "ApiKeyMissingError",
+    "AzureApiBaseMissingError",
+    "AzureDeploymentNameMissingError",
+    "CacheConfig",
+    "CacheConfigInput",
+    "CacheType",
+    "ChunkingConfig",
+    "ChunkingConfigInput",
+    "ClaimExtractionConfig",
+    "ClaimExtractionConfigInput",
+    "ClusterGraphConfig",
+    "ClusterGraphConfigInput",
+    "CommunityReportsConfig",
+    "CommunityReportsConfigInput",
+    "ContextSwitchType",
+    "EmbedGraphConfig",
+    "EmbedGraphConfigInput",
+    "EntityExtractionConfig",
+    "EntityExtractionConfigInput",
+    "GlobalSearchConfig",
+    "GlobalSearchConfigInput",
+    "GraphRagConfig",
+    "GraphRagConfigInput",
+    "InputConfig",
+    "InputConfigInput",
+    "InputFileType",
+    "InputType",
+    "LLMConfig",
+    "LLMConfigInput",
+    "LLMParameters",
+    "LLMParametersInput",
+    "LLMType",
+    "LocalSearchConfig",
+    "LocalSearchConfigInput",
+    "ParallelizationParameters",
+    "ParallelizationParametersInput",
+    "QueryContextConfig",
+    "QueryContextConfigInput",
+    "ReportingConfig",
+    "ReportingConfigInput",
+    "ReportingType",
+    "SnapshotsConfig",
+    "SnapshotsConfigInput",
+    "StorageConfig",
+    "StorageConfigInput",
+    "StorageType",
+    "SummarizeDescriptionsConfig",
+    "SummarizeDescriptionsConfigInput",
+    "TextEmbeddingConfig",
+    "TextEmbeddingConfigInput",
+    "TextEmbeddingTarget",
+    "UmapConfig",
+    "UmapConfigInput",
+    "create_graphrag_config",
+    "read_dotenv",
+]
diff --git a/func-app/graphrag/config/create_graphrag_config.py b/func-app/graphrag/config/create_graphrag_config.py
new file mode 100644
index 0000000000..34b953345e
--- /dev/null
+++ b/func-app/graphrag/config/create_graphrag_config.py
@@ -0,0 +1,687 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration, loaded from environment variables."""
+
+import os
+from enum import Enum
+from pathlib import Path
+from typing import cast
+
+from datashaper import AsyncType
+from environs import Env
+from pydantic import TypeAdapter
+
+import graphrag.config.defaults as defs
+
+from .enums import (
+    CacheType,
+    InputFileType,
+    InputType,
+    LLMType,
+    ReportingType,
+    StorageType,
+    TextEmbeddingTarget,
+)
+from .environment_reader import EnvironmentReader
+from .errors import (
+    ApiKeyMissingError,
+    AzureApiBaseMissingError,
+    AzureDeploymentNameMissingError,
+)
+from .input_models import (
+    GraphRagConfigInput,
+    LLMConfigInput,
+)
+from .models import (
+    CacheConfig,
+    ChunkingConfig,
+    ClaimExtractionConfig,
+    ClusterGraphConfig,
+    CommunityReportsConfig,
+    EmbedGraphConfig,
+    EntityExtractionConfig,
+    GlobalSearchConfig,
+    GraphDBConfig,
+    GraphRagConfig,
+    InputConfig,
+    LLMParameters,
+    LocalSearchConfig,
+    ParallelizationParameters,
+    QueryContextConfig,
+    ReportingConfig,
+    SnapshotsConfig,
+    StorageConfig,
+    SummarizeDescriptionsConfig,
+    TextEmbeddingConfig,
+    UmapConfig,
+)
+from .read_dotenv import read_dotenv
+
+InputModelValidator = TypeAdapter(GraphRagConfigInput)
+
+
+def create_graphrag_config(
+    values: GraphRagConfigInput | None = None, root_dir: str | None = None
+) -> GraphRagConfig:
+    """Load Configuration Parameters from a dictionary."""
+    values = values or {}
+    root_dir = root_dir or str(Path.cwd())
+    env = _make_env(root_dir)
+    _token_replace(cast(dict, values))
+    InputModelValidator.validate_python(values, strict=True)
+
+    reader = EnvironmentReader(env)
+
+    def hydrate_async_type(input: LLMConfigInput, base: AsyncType) -> AsyncType:
+        value = input.get(Fragment.async_mode)
+        return AsyncType(value) if value else base
+
+    def hydrate_llm_params(
+        config: LLMConfigInput, base: LLMParameters
+    ) -> LLMParameters:
+        with reader.use(config.get("llm")):
+            llm_type = reader.str(Fragment.type)
+            llm_type = LLMType(llm_type) if llm_type else base.type
+            api_key = reader.str(Fragment.api_key) or base.api_key
+            api_base = reader.str(Fragment.api_base) or base.api_base
+            cognitive_services_endpoint = (
+                reader.str(Fragment.cognitive_services_endpoint)
+                or base.cognitive_services_endpoint
+            )
+            deployment_name = (
+                reader.str(Fragment.deployment_name) or base.deployment_name
+            )
+
+            if api_key is None and not _is_azure(llm_type):
+                raise ApiKeyMissingError
+            if _is_azure(llm_type):
+                if api_base is None:
+                    raise AzureApiBaseMissingError
+                if deployment_name is None:
+                    raise AzureDeploymentNameMissingError
+
+            sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation)
+            if sleep_on_rate_limit is None:
+                sleep_on_rate_limit = base.sleep_on_rate_limit_recommendation
+
+            return LLMParameters(
+                api_key=api_key,
+                type=llm_type,
+                api_base=api_base,
+                api_version=reader.str(Fragment.api_version) or base.api_version,
+                organization=reader.str("organization") or base.organization,
+                proxy=reader.str("proxy") or base.proxy,
+                model=reader.str("model") or base.model,
+                max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
+                temperature=reader.float(Fragment.temperature) or base.temperature,
+                top_p=reader.float(Fragment.top_p) or base.top_p,
+                n=reader.int(Fragment.n) or base.n,
+                model_supports_json=reader.bool(Fragment.model_supports_json)
+                or base.model_supports_json,
+                request_timeout=reader.float(Fragment.request_timeout)
+                or base.request_timeout,
+                cognitive_services_endpoint=cognitive_services_endpoint,
+                deployment_name=deployment_name,
+                tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm)
+                or base.tokens_per_minute,
+                requests_per_minute=reader.int("requests_per_minute", Fragment.rpm)
+                or base.requests_per_minute,
+                max_retries=reader.int(Fragment.max_retries) or base.max_retries,
+                max_retry_wait=reader.float(Fragment.max_retry_wait)
+                or base.max_retry_wait,
+                sleep_on_rate_limit_recommendation=sleep_on_rate_limit,
+                concurrent_requests=reader.int(Fragment.concurrent_requests)
+                or base.concurrent_requests,
+            )
+
+    def hydrate_embeddings_params(
+        config: LLMConfigInput, base: LLMParameters
+    ) -> LLMParameters:
+        with reader.use(config.get("llm")):
+            api_type = reader.str(Fragment.type) or defs.EMBEDDING_TYPE
+            api_type = LLMType(api_type) if api_type else defs.LLM_TYPE
+            api_key = reader.str(Fragment.api_key) or base.api_key
+
+            # Cover the cases where:
+            # - the same api_base is used for the LLM and embeddings (both Azure)
+            # - different api_bases are used for the LLM and embeddings (both Azure)
+            # - the LLM uses Azure OpenAI and embeddings use base OpenAI (this one is important)
+            # - the LLM uses Azure OpenAI and embeddings use a third-party OpenAI-like API
+            api_base = (
+                reader.str(Fragment.api_base) or base.api_base
+                if _is_azure(api_type)
+                else reader.str(Fragment.api_base)
+            )
+            api_version = (
+                reader.str(Fragment.api_version) or base.api_version
+                if _is_azure(api_type)
+                else reader.str(Fragment.api_version)
+            )
+            api_organization = reader.str("organization") or base.organization
+            api_proxy = reader.str("proxy") or base.proxy
+            cognitive_services_endpoint = (
+                reader.str(Fragment.cognitive_services_endpoint)
+                or base.cognitive_services_endpoint
+            )
+            deployment_name = reader.str(Fragment.deployment_name)
+
+            if api_key is None and not _is_azure(api_type):
+                raise ApiKeyMissingError(embedding=True)
+            if _is_azure(api_type):
+                if api_base is None:
+                    raise AzureApiBaseMissingError(embedding=True)
+                if deployment_name is None:
+                    raise AzureDeploymentNameMissingError(embedding=True)
+
+            sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation)
+            if sleep_on_rate_limit is None:
+                sleep_on_rate_limit = base.sleep_on_rate_limit_recommendation
+
+            return LLMParameters(
+                api_key=api_key,
+                type=api_type,
+
api_base=api_base, + api_version=api_version, + organization=api_organization, + proxy=api_proxy, + model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL, + request_timeout=reader.float(Fragment.request_timeout) + or defs.LLM_REQUEST_TIMEOUT, + cognitive_services_endpoint=cognitive_services_endpoint, + deployment_name=deployment_name, + tokens_per_minute=reader.int("tokens_per_minute", Fragment.tpm) + or defs.LLM_TOKENS_PER_MINUTE, + requests_per_minute=reader.int("requests_per_minute", Fragment.rpm) + or defs.LLM_REQUESTS_PER_MINUTE, + max_retries=reader.int(Fragment.max_retries) or defs.LLM_MAX_RETRIES, + max_retry_wait=reader.float(Fragment.max_retry_wait) + or defs.LLM_MAX_RETRY_WAIT, + sleep_on_rate_limit_recommendation=sleep_on_rate_limit, + concurrent_requests=reader.int(Fragment.concurrent_requests) + or defs.LLM_CONCURRENT_REQUESTS, + ) + + def hydrate_parallelization_params( + config: LLMConfigInput, base: ParallelizationParameters + ) -> ParallelizationParameters: + with reader.use(config.get("parallelization")): + return ParallelizationParameters( + num_threads=reader.int("num_threads", Fragment.thread_count) + or base.num_threads, + stagger=reader.float("stagger", Fragment.thread_stagger) + or base.stagger, + ) + + fallback_oai_key = env("OPENAI_API_KEY", env("AZURE_OPENAI_API_KEY", None)) + fallback_oai_org = env("OPENAI_ORG_ID", None) + fallback_oai_base = env("OPENAI_BASE_URL", None) + fallback_oai_version = env("OPENAI_API_VERSION", None) + + with reader.envvar_prefix(Section.graphrag), reader.use(values): + async_mode = reader.str(Fragment.async_mode) + async_mode = AsyncType(async_mode) if async_mode else defs.ASYNC_MODE + + fallback_oai_key = reader.str(Fragment.api_key) or fallback_oai_key + fallback_oai_org = reader.str(Fragment.api_organization) or fallback_oai_org + fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base + fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version + fallback_oai_proxy = reader.str(Fragment.api_proxy) + + with reader.envvar_prefix(Section.llm): + with reader.use(values.get("llm")): + llm_type = reader.str(Fragment.type) + llm_type = LLMType(llm_type) if llm_type else defs.LLM_TYPE + api_key = reader.str(Fragment.api_key) or fallback_oai_key + api_organization = ( + reader.str(Fragment.api_organization) or fallback_oai_org + ) + api_base = reader.str(Fragment.api_base) or fallback_oai_base + api_version = reader.str(Fragment.api_version) or fallback_oai_version + api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy + cognitive_services_endpoint = reader.str( + Fragment.cognitive_services_endpoint + ) + deployment_name = reader.str(Fragment.deployment_name) + + if api_key is None and not _is_azure(llm_type): + raise ApiKeyMissingError + if _is_azure(llm_type): + if api_base is None: + raise AzureApiBaseMissingError + if deployment_name is None: + raise AzureDeploymentNameMissingError + + sleep_on_rate_limit = reader.bool(Fragment.sleep_recommendation) + if sleep_on_rate_limit is None: + sleep_on_rate_limit = defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION + + llm_model = LLMParameters( + api_key=api_key, + api_base=api_base, + api_version=api_version, + organization=api_organization, + proxy=api_proxy, + type=llm_type, + model=reader.str(Fragment.model) or defs.LLM_MODEL, + max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS, + temperature=reader.float(Fragment.temperature) + or defs.LLM_TEMPERATURE, + top_p=reader.float(Fragment.top_p) or defs.LLM_TOP_P, + 
n=reader.int(Fragment.n) or defs.LLM_N, + model_supports_json=reader.bool(Fragment.model_supports_json), + request_timeout=reader.float(Fragment.request_timeout) + or defs.LLM_REQUEST_TIMEOUT, + cognitive_services_endpoint=cognitive_services_endpoint, + deployment_name=deployment_name, + tokens_per_minute=reader.int(Fragment.tpm) + or defs.LLM_TOKENS_PER_MINUTE, + requests_per_minute=reader.int(Fragment.rpm) + or defs.LLM_REQUESTS_PER_MINUTE, + max_retries=reader.int(Fragment.max_retries) + or defs.LLM_MAX_RETRIES, + max_retry_wait=reader.float(Fragment.max_retry_wait) + or defs.LLM_MAX_RETRY_WAIT, + sleep_on_rate_limit_recommendation=sleep_on_rate_limit, + concurrent_requests=reader.int(Fragment.concurrent_requests) + or defs.LLM_CONCURRENT_REQUESTS, + ) + with reader.use(values.get("parallelization")): + llm_parallelization_model = ParallelizationParameters( + stagger=reader.float("stagger", Fragment.thread_stagger) + or defs.PARALLELIZATION_STAGGER, + num_threads=reader.int("num_threads", Fragment.thread_count) + or defs.PARALLELIZATION_NUM_THREADS, + ) + embeddings_config = values.get("embeddings") or {} + with reader.envvar_prefix(Section.embedding), reader.use(embeddings_config): + embeddings_target = reader.str("target") + embeddings_model = TextEmbeddingConfig( + llm=hydrate_embeddings_params(embeddings_config, llm_model), + parallelization=hydrate_parallelization_params( + embeddings_config, llm_parallelization_model + ), + vector_store=embeddings_config.get("vector_store", None), + async_mode=hydrate_async_type(embeddings_config, async_mode), + target=( + TextEmbeddingTarget(embeddings_target) + if embeddings_target + else defs.EMBEDDING_TARGET + ), + batch_size=reader.int("batch_size") or defs.EMBEDDING_BATCH_SIZE, + batch_max_tokens=reader.int("batch_max_tokens") + or defs.EMBEDDING_BATCH_MAX_TOKENS, + skip=reader.list("skip") or [], + ) + with ( + reader.envvar_prefix(Section.node2vec), + reader.use(values.get("embed_graph")), + ): + embed_graph_model = EmbedGraphConfig( + enabled=reader.bool(Fragment.enabled) or defs.NODE2VEC_ENABLED, + num_walks=reader.int("num_walks") or defs.NODE2VEC_NUM_WALKS, + walk_length=reader.int("walk_length") or defs.NODE2VEC_WALK_LENGTH, + window_size=reader.int("window_size") or defs.NODE2VEC_WINDOW_SIZE, + iterations=reader.int("iterations") or defs.NODE2VEC_ITERATIONS, + random_seed=reader.int("random_seed") or defs.NODE2VEC_RANDOM_SEED, + ) + with reader.envvar_prefix(Section.input), reader.use(values.get("input")): + input_type = reader.str("type") + file_type = reader.str(Fragment.file_type) + input_model = InputConfig( + file_type=( + InputFileType(file_type) if file_type else defs.INPUT_FILE_TYPE + ), + type=(InputType(input_type) if input_type else defs.INPUT_TYPE), + encoding=reader.str("file_encoding", Fragment.encoding) + or defs.INPUT_FILE_ENCODING, + base_dir=reader.str(Fragment.base_dir) or defs.INPUT_BASE_DIR, + file_pattern=reader.str("file_pattern") + or ( + defs.INPUT_TEXT_PATTERN + if file_type == InputFileType.text + else defs.INPUT_CSV_PATTERN + ), + source_column=reader.str("source_column"), + timestamp_column=reader.str("timestamp_column"), + timestamp_format=reader.str("timestamp_format"), + text_column=reader.str("text_column") or defs.INPUT_TEXT_COLUMN, + title_column=reader.str("title_column"), + document_attribute_columns=reader.list("document_attribute_columns") + or [], + connection_string=reader.str(Fragment.conn_string), + storage_account_blob_url=reader.str(Fragment.storage_account_blob_url), + 
container_name=reader.str(Fragment.container_name), + ) + with reader.envvar_prefix(Section.cache), reader.use(values.get("cache")): + c_type = reader.str(Fragment.type) + cache_model = CacheConfig( + type=CacheType(c_type) if c_type else defs.CACHE_TYPE, + connection_string=reader.str(Fragment.conn_string), + storage_account_blob_url=reader.str(Fragment.storage_account_blob_url), + container_name=reader.str(Fragment.container_name), + base_dir=reader.str(Fragment.base_dir) or defs.CACHE_BASE_DIR, + ) + with ( + reader.envvar_prefix(Section.reporting), + reader.use(values.get("reporting")), + ): + r_type = reader.str(Fragment.type) + reporting_model = ReportingConfig( + type=ReportingType(r_type) if r_type else defs.REPORTING_TYPE, + connection_string=reader.str(Fragment.conn_string), + storage_account_blob_url=reader.str(Fragment.storage_account_blob_url), + container_name=reader.str(Fragment.container_name), + base_dir=reader.str(Fragment.base_dir) or defs.REPORTING_BASE_DIR, + ) + with reader.envvar_prefix(Section.storage), reader.use(values.get("storage")): + s_type = reader.str(Fragment.type) + storage_model = StorageConfig( + type=StorageType(s_type) if s_type else defs.STORAGE_TYPE, + connection_string=reader.str(Fragment.conn_string), + storage_account_blob_url=reader.str(Fragment.storage_account_blob_url), + container_name=reader.str(Fragment.container_name), + base_dir=reader.str(Fragment.base_dir) or defs.STORAGE_BASE_DIR, + overwrite=reader.bool(Fragment.overwrite) or False + ) + with reader.envvar_prefix(Section.chunk), reader.use(values.get("chunks")): + group_by_columns = reader.list("group_by_columns", "BY_COLUMNS") + if group_by_columns is None: + group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS + + chunks_model = ChunkingConfig( + size=reader.int("size") or defs.CHUNK_SIZE, + overlap=reader.int("overlap") or defs.CHUNK_OVERLAP, + group_by_columns=group_by_columns, + encoding_model=reader.str(Fragment.encoding_model), + ) + with ( + reader.envvar_prefix(Section.snapshot), + reader.use(values.get("snapshots")), + ): + snapshots_model = SnapshotsConfig( + graphml=reader.bool("graphml") or defs.SNAPSHOTS_GRAPHML, + raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES, + top_level_nodes=reader.bool("top_level_nodes") + or defs.SNAPSHOTS_TOP_LEVEL_NODES, + ) + with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")): + umap_model = UmapConfig( + enabled=reader.bool(Fragment.enabled) or defs.UMAP_ENABLED, + ) + + entity_extraction_config = values.get("entity_extraction") or {} + with ( + reader.envvar_prefix(Section.entity_extraction), + reader.use(entity_extraction_config), + ): + max_gleanings = reader.int(Fragment.max_gleanings) + max_gleanings = ( + max_gleanings + if max_gleanings is not None + else defs.ENTITY_EXTRACTION_MAX_GLEANINGS + ) + + entity_extraction_model = EntityExtractionConfig( + llm=hydrate_llm_params(entity_extraction_config, llm_model), + parallelization=hydrate_parallelization_params( + entity_extraction_config, llm_parallelization_model + ), + async_mode=hydrate_async_type(entity_extraction_config, async_mode), + entity_types=reader.list("entity_types") + or defs.ENTITY_EXTRACTION_ENTITY_TYPES, + max_gleanings=max_gleanings, + prompt=reader.str("prompt", Fragment.prompt_file), + encoding_model=reader.str(Fragment.encoding_model), + ) + + claim_extraction_config = values.get("claim_extraction") or {} + with ( + reader.envvar_prefix(Section.claim_extraction), + reader.use(claim_extraction_config), + ): + max_gleanings = 
reader.int(Fragment.max_gleanings) + max_gleanings = ( + max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS + ) + claim_extraction_model = ClaimExtractionConfig( + enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED, + llm=hydrate_llm_params(claim_extraction_config, llm_model), + parallelization=hydrate_parallelization_params( + claim_extraction_config, llm_parallelization_model + ), + async_mode=hydrate_async_type(claim_extraction_config, async_mode), + description=reader.str("description") or defs.CLAIM_DESCRIPTION, + prompt=reader.str("prompt", Fragment.prompt_file), + max_gleanings=max_gleanings, + encoding_model=reader.str(Fragment.encoding_model), + ) + + community_report_config = values.get("community_reports") or {} + with ( + reader.envvar_prefix(Section.community_reports), + reader.use(community_report_config), + ): + community_reports_model = CommunityReportsConfig( + llm=hydrate_llm_params(community_report_config, llm_model), + parallelization=hydrate_parallelization_params( + community_report_config, llm_parallelization_model + ), + async_mode=hydrate_async_type(community_report_config, async_mode), + prompt=reader.str("prompt", Fragment.prompt_file), + max_length=reader.int(Fragment.max_length) + or defs.COMMUNITY_REPORT_MAX_LENGTH, + max_input_length=reader.int("max_input_length") + or defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH, + ) + + summarize_description_config = values.get("summarize_descriptions") or {} + with ( + reader.envvar_prefix(Section.summarize_descriptions), + reader.use(values.get("summarize_descriptions")), + ): + summarize_descriptions_model = SummarizeDescriptionsConfig( + llm=hydrate_llm_params(summarize_description_config, llm_model), + parallelization=hydrate_parallelization_params( + summarize_description_config, llm_parallelization_model + ), + async_mode=hydrate_async_type(summarize_description_config, async_mode), + prompt=reader.str("prompt", Fragment.prompt_file), + max_length=reader.int(Fragment.max_length) + or defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH, + ) + + with reader.use(values.get("cluster_graph")): + cluster_graph_model = ClusterGraphConfig( + max_cluster_size=reader.int("max_cluster_size") or defs.MAX_CLUSTER_SIZE + ) + + with ( + reader.use(values.get("local_search")), + reader.envvar_prefix(Section.local_search), + ): + local_search_model = LocalSearchConfig( + text_unit_prop=reader.float("text_unit_prop") + or defs.LOCAL_SEARCH_TEXT_UNIT_PROP, + community_prop=reader.float("community_prop") + or defs.LOCAL_SEARCH_COMMUNITY_PROP, + conversation_history_max_turns=reader.int( + "conversation_history_max_turns" + ) + or defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS, + top_k_entities=reader.int("top_k_entities") + or defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES, + top_k_relationships=reader.int("top_k_relationships") + or defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS, + temperature=reader.float("llm_temperature") + or defs.LOCAL_SEARCH_LLM_TEMPERATURE, + top_p=reader.float("llm_top_p") or defs.LOCAL_SEARCH_LLM_TOP_P, + n=reader.int("llm_n") or defs.LOCAL_SEARCH_LLM_N, + max_tokens=reader.int(Fragment.max_tokens) + or defs.LOCAL_SEARCH_MAX_TOKENS, + llm_max_tokens=reader.int("llm_max_tokens") + or defs.LOCAL_SEARCH_LLM_MAX_TOKENS, + ) + + with ( + reader.use(values.get("global_search")), + reader.envvar_prefix(Section.global_search), + ): + global_search_model = GlobalSearchConfig( + temperature=reader.float("llm_temperature") + or defs.GLOBAL_SEARCH_LLM_TEMPERATURE, + top_p=reader.float("llm_top_p") or 
defs.GLOBAL_SEARCH_LLM_TOP_P,
+            n=reader.int("llm_n") or defs.GLOBAL_SEARCH_LLM_N,
+            max_tokens=reader.int(Fragment.max_tokens)
+            or defs.GLOBAL_SEARCH_MAX_TOKENS,
+            data_max_tokens=reader.int("data_max_tokens")
+            or defs.GLOBAL_SEARCH_DATA_MAX_TOKENS,
+            map_max_tokens=reader.int("map_max_tokens")
+            or defs.GLOBAL_SEARCH_MAP_MAX_TOKENS,
+            reduce_max_tokens=reader.int("reduce_max_tokens")
+            or defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS,
+            concurrency=reader.int("concurrency") or defs.GLOBAL_SEARCH_CONCURRENCY,
+        )
+
+    with (
+        reader.use(values.get("query_context")),
+        reader.envvar_prefix(Section.query_context),
+    ):
+        query_context_model = QueryContextConfig(
+            files=reader.list("files") or [],
+        )
+
+    with (
+        reader.use(values.get("graphdb")),
+        reader.envvar_prefix(Section.graphdb),
+    ):
+        graphdb_model = GraphDBConfig(
+            account_name=reader.str("account_name") or None,
+            account_key=reader.str("account_key") or None,
+            username=reader.str("username") or None,
+            enabled=reader.bool("enabled") or False,
+            cosmos_url=reader.str("cosmos_url") or None,
+            gremlin_url=reader.str("gremlin_url") or None,
+        )
+
+    encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
+    skip_workflows = reader.list("skip_workflows") or []
+
+    return GraphRagConfig(
+        root_dir=root_dir,
+        llm=llm_model,
+        parallelization=llm_parallelization_model,
+        async_mode=async_mode,
+        embeddings=embeddings_model,
+        embed_graph=embed_graph_model,
+        reporting=reporting_model,
+        storage=storage_model,
+        cache=cache_model,
+        input=input_model,
+        chunks=chunks_model,
+        snapshots=snapshots_model,
+        entity_extraction=entity_extraction_model,
+        claim_extraction=claim_extraction_model,
+        community_reports=community_reports_model,
+        summarize_descriptions=summarize_descriptions_model,
+        umap=umap_model,
+        cluster_graph=cluster_graph_model,
+        encoding_model=encoding_model,
+        skip_workflows=skip_workflows,
+        local_search=local_search_model,
+        global_search=global_search_model,
+        query_context=query_context_model,
+        graphdb=graphdb_model,
+    )
+
+
+class Fragment(str, Enum):
+    """Configuration Fragments."""
+
+    api_base = "API_BASE"
+    api_key = "API_KEY"
+    api_version = "API_VERSION"
+    api_organization = "API_ORGANIZATION"
+    api_proxy = "API_PROXY"
+    async_mode = "ASYNC_MODE"
+    base_dir = "BASE_DIR"
+    overwrite = "OVERWRITE"
+    cognitive_services_endpoint = "COGNITIVE_SERVICES_ENDPOINT"
+    concurrent_requests = "CONCURRENT_REQUESTS"
+    conn_string = "CONNECTION_STRING"
+    container_name = "CONTAINER_NAME"
+    deployment_name = "DEPLOYMENT_NAME"
+    description = "DESCRIPTION"
+    enabled = "ENABLED"
+    encoding = "ENCODING"
+    encoding_model = "ENCODING_MODEL"
+    file_type = "FILE_TYPE"
+    max_gleanings = "MAX_GLEANINGS"
+    max_length = "MAX_LENGTH"
+    max_retries = "MAX_RETRIES"
+    max_retry_wait = "MAX_RETRY_WAIT"
+    max_tokens = "MAX_TOKENS"
+    temperature = "TEMPERATURE"
+    top_p = "TOP_P"
+    n = "N"
+    model = "MODEL"
+    model_supports_json = "MODEL_SUPPORTS_JSON"
+    prompt_file = "PROMPT_FILE"
+    request_timeout = "REQUEST_TIMEOUT"
+    rpm = "REQUESTS_PER_MINUTE"
+    sleep_recommendation = "SLEEP_ON_RATE_LIMIT_RECOMMENDATION"
+    storage_account_blob_url = "STORAGE_ACCOUNT_BLOB_URL"
+    thread_count = "THREAD_COUNT"
+    thread_stagger = "THREAD_STAGGER"
+    tpm = "TOKENS_PER_MINUTE"
+    type = "TYPE"
+    output = "OUTPUT"
+
+
+class Section(str, Enum):
+    """Configuration Sections."""
+
+    base = "BASE"
+    cache = "CACHE"
+    chunk = "CHUNK"
+    claim_extraction = "CLAIM_EXTRACTION"
+    community_reports = "COMMUNITY_REPORTS"
+    embedding = "EMBEDDING"
+
entity_extraction = "ENTITY_EXTRACTION" + graphrag = "GRAPHRAG" + input = "INPUT" + llm = "LLM" + node2vec = "NODE2VEC" + reporting = "REPORTING" + snapshot = "SNAPSHOT" + storage = "STORAGE" + summarize_descriptions = "SUMMARIZE_DESCRIPTIONS" + umap = "UMAP" + local_search = "LOCAL_SEARCH" + global_search = "GLOBAL_SEARCH" + query_context = "QUERY_CONTEXT" + graphdb = "GRAPHDB" + + +def _is_azure(llm_type: LLMType | None) -> bool: + return ( + llm_type == LLMType.AzureOpenAI + or llm_type == LLMType.AzureOpenAIChat + or llm_type == LLMType.AzureOpenAIEmbedding + ) + + +def _make_env(root_dir: str) -> Env: + read_dotenv(root_dir) + env = Env(expand_vars=True) + env.read_env() + return env + + +def _token_replace(data: dict): + """Replace env-var tokens in a dictionary object.""" + for key, value in data.items(): + if isinstance(value, dict): + _token_replace(value) + elif isinstance(value, str): + data[key] = os.path.expandvars(value) diff --git a/func-app/graphrag/config/defaults.py b/func-app/graphrag/config/defaults.py new file mode 100644 index 0000000000..4d6489140a --- /dev/null +++ b/func-app/graphrag/config/defaults.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Common default configuration values.""" + +from datashaper import AsyncType + +from .enums import ( + CacheType, + InputFileType, + InputType, + LLMType, + ReportingType, + StorageType, + TextEmbeddingTarget, +) + +ASYNC_MODE = AsyncType.Threaded +ENCODING_MODEL = "cl100k_base" +# +# LLM Parameters +# +LLM_TYPE = LLMType.OpenAIChat +LLM_MODEL = "gpt-4-turbo-preview" +LLM_MAX_TOKENS = 4000 +LLM_TEMPERATURE = 0 +LLM_TOP_P = 1 +LLM_N = 1 +LLM_REQUEST_TIMEOUT = 180.0 +LLM_TOKENS_PER_MINUTE = 0 +LLM_REQUESTS_PER_MINUTE = 0 +LLM_MAX_RETRIES = 10 +LLM_MAX_RETRY_WAIT = 10.0 +LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION = True +LLM_CONCURRENT_REQUESTS = 25 + +# +# Text Embedding Parameters +# +EMBEDDING_TYPE = LLMType.OpenAIEmbedding +EMBEDDING_MODEL = "text-embedding-3-small" +EMBEDDING_BATCH_SIZE = 16 +EMBEDDING_BATCH_MAX_TOKENS = 8191 +EMBEDDING_TARGET = TextEmbeddingTarget.required + +CACHE_TYPE = CacheType.file +CACHE_BASE_DIR = "cache" +CHUNK_SIZE = 1200 +CHUNK_OVERLAP = 100 +CHUNK_GROUP_BY_COLUMNS = ["id"] +CLAIM_DESCRIPTION = ( + "Any claims or facts that could be relevant to information discovery." 
+) +CLAIM_MAX_GLEANINGS = 1 +CLAIM_EXTRACTION_ENABLED = False +MAX_CLUSTER_SIZE = 10 +COMMUNITY_REPORT_MAX_LENGTH = 2000 +COMMUNITY_REPORT_MAX_INPUT_LENGTH = 8000 +ENTITY_EXTRACTION_ENTITY_TYPES = ["organization", "person", "geo", "event"] +ENTITY_EXTRACTION_MAX_GLEANINGS = 1 +INPUT_FILE_TYPE = InputFileType.text +INPUT_TYPE = InputType.file +INPUT_BASE_DIR = "input" +INPUT_FILE_ENCODING = "utf-8" +INPUT_TEXT_COLUMN = "text" +INPUT_CSV_PATTERN = ".*\\.csv$" +INPUT_TEXT_PATTERN = ".*\\.txt$" +PARALLELIZATION_STAGGER = 0.3 +PARALLELIZATION_NUM_THREADS = 50 +NODE2VEC_ENABLED = False +NODE2VEC_NUM_WALKS = 10 +NODE2VEC_WALK_LENGTH = 40 +NODE2VEC_WINDOW_SIZE = 2 +NODE2VEC_ITERATIONS = 3 +NODE2VEC_RANDOM_SEED = 597832 +REPORTING_TYPE = ReportingType.file +REPORTING_BASE_DIR = "output/${timestamp}/reports" +SNAPSHOTS_GRAPHML = False +SNAPSHOTS_RAW_ENTITIES = False +SNAPSHOTS_TOP_LEVEL_NODES = False +STORAGE_BASE_DIR = "output/${timestamp}/artifacts" +STORAGE_TYPE = StorageType.file +SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500 +UMAP_ENABLED = False + +# Local Search +LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5 +LOCAL_SEARCH_COMMUNITY_PROP = 0.1 +LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS = 5 +LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES = 10 +LOCAL_SEARCH_TOP_K_RELATIONSHIPS = 10 +LOCAL_SEARCH_MAX_TOKENS = 12_000 +LOCAL_SEARCH_LLM_TEMPERATURE = 0 +LOCAL_SEARCH_LLM_TOP_P = 1 +LOCAL_SEARCH_LLM_N = 1 +LOCAL_SEARCH_LLM_MAX_TOKENS = 2000 + +# Global Search +GLOBAL_SEARCH_LLM_TEMPERATURE = 0 +GLOBAL_SEARCH_LLM_TOP_P = 1 +GLOBAL_SEARCH_LLM_N = 1 +GLOBAL_SEARCH_MAX_TOKENS = 12_000 +GLOBAL_SEARCH_DATA_MAX_TOKENS = 12_000 +GLOBAL_SEARCH_MAP_MAX_TOKENS = 1000 +GLOBAL_SEARCH_REDUCE_MAX_TOKENS = 2_000 +GLOBAL_SEARCH_CONCURRENCY = 32 diff --git a/func-app/graphrag/config/enums.py b/func-app/graphrag/config/enums.py new file mode 100644 index 0000000000..4745acc5f5 --- /dev/null +++ b/func-app/graphrag/config/enums.py @@ -0,0 +1,127 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A module containing configuration enums for the default pipeline."""
+
+from __future__ import annotations
+
+from enum import Enum
+
+
+class CacheType(str, Enum):
+    """The cache configuration type for the pipeline."""
+
+    file = "file"
+    """The file cache configuration type."""
+    memory = "memory"
+    """The memory cache configuration type."""
+    none = "none"
+    """The none cache configuration type."""
+    blob = "blob"
+    """The blob cache configuration type."""
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+class InputFileType(str, Enum):
+    """The input file type for the pipeline."""
+
+    csv = "csv"
+    """The CSV input type."""
+    text = "text"
+    """The text input type."""
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+class InputType(str, Enum):
+    """The input type for the pipeline."""
+
+    file = "file"
+    """The file storage type."""
+    blob = "blob"
+    """The blob storage type."""
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+class StorageType(str, Enum):
+    """The storage type for the pipeline."""
+
+    file = "file"
+    """The file storage type."""
+    memory = "memory"
+    """The memory storage type."""
+    blob = "blob"
+    """The blob storage type."""
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+class ReportingType(str, Enum):
+    """The reporting configuration type for the pipeline."""
+
+    file = "file"
+    """The file reporting configuration type."""
+    console = "console"
+    """The console reporting configuration type."""
+    blob = "blob"
+    """The blob reporting configuration type."""
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+class TextEmbeddingTarget(str, Enum):
+    """The target to use for text embeddings."""
+
+    all = "all"
+    required = "required"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+class LLMType(str, Enum):
+    """LLMType enum class definition."""
+
+    # Embeddings
+    OpenAIEmbedding = "openai_embedding"
+    AzureOpenAIEmbedding = "azure_openai_embedding"
+
+    # Raw Completion
+    OpenAI = "openai"
+    AzureOpenAI = "azure_openai"
+
+    # Chat Completion
+    OpenAIChat = "openai_chat"
+    AzureOpenAIChat = "azure_openai_chat"
+
+    # Debug
+    StaticResponse = "static_response"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+class ContextSwitchType(str, Enum):
+    """The context switch type."""
+
+    Activate = "activate"
+    Deactivate = "deactivate"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
diff --git a/func-app/graphrag/config/environment_reader.py b/func-app/graphrag/config/environment_reader.py
new file mode 100644
index 0000000000..258422666c
--- /dev/null
+++ b/func-app/graphrag/config/environment_reader.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A configuration reader utility class."""
+
+from collections.abc import Callable
+from contextlib import contextmanager
+from enum import Enum
+from typing import Any, TypeVar
+
+from environs import Env
+
+T = TypeVar("T")
+
+KeyValue = str | Enum
+EnvKeySet = str | list[str]
+
+
+def read_key(value: KeyValue) -> str:
+    """Read a key value."""
+    if not isinstance(value, str):
+        return value.value.lower()
+    return value.lower()
+
+
+class EnvironmentReader:
+    """A configuration reader utility class."""
+
+    _env: Env
+    _config_stack: list[dict]
+
+    def __init__(self, env: Env):
+        self._env = env
+        self._config_stack = []
+
+    @property
+    def env(self):
+        """Get the environment object."""
+        return self._env
+
+    def _read_env(
+        self, env_key: str | list[str], default_value: T, read: Callable[[str, T], T]
+    ) -> T | None:
+        if isinstance(env_key, str):
+            env_key = [env_key]
+
+        for k in env_key:
+            result = read(k.upper(), default_value)
+            if result is not default_value:
+                return result
+
+        return default_value
+
+    def envvar_prefix(self, prefix: KeyValue):
+        """Set the environment variable prefix."""
+        prefix = read_key(prefix)
+        prefix = f"{prefix}_".upper()
+        return self._env.prefixed(prefix)
+
+    def use(self, value: Any | None):
+        """Create a context manager to push the value into the config_stack."""
+
+        @contextmanager
+        def config_context():
+            self._config_stack.append(value or {})
+            try:
+                yield
+            finally:
+                self._config_stack.pop()
+
+        return config_context()
+
+    @property
+    def section(self) -> dict:
+        """Get the current section."""
+        return self._config_stack[-1] if self._config_stack else {}
+
+    def str(
+        self,
+        key: KeyValue,
+        env_key: EnvKeySet | None = None,
+        default_value: str | None = None,
+    ) -> str | None:
+        """Read a configuration value."""
+        key = read_key(key)
+        if self.section and key in self.section:
+            return self.section[key]
+
+        return self._read_env(
+            env_key or key, default_value, (lambda k, dv: self._env(k, dv))
+        )
+
+    def int(
+        self,
+        key: KeyValue,
+        env_key: EnvKeySet | None = None,
+        default_value: int | None = None,
+    ) -> int | None:
+        """Read an integer configuration value."""
+        key = read_key(key)
+        if self.section and key in self.section:
+            return int(self.section[key])
+        return self._read_env(
+            env_key or key, default_value, lambda k, dv: self._env.int(k, dv)
+        )
+
+    def bool(
+        self,
+        key: KeyValue,
+        env_key: EnvKeySet | None = None,
+        default_value: bool | None = None,
+    ) -> bool | None:
+        """Read a boolean configuration value."""
+        key = read_key(key)
+        if self.section and key in self.section:
+            return bool(self.section[key])
+
+        return self._read_env(
+            env_key or key, default_value, lambda k, dv: self._env.bool(k, dv)
+        )
+
+    def float(
+        self,
+        key: KeyValue,
+        env_key: EnvKeySet | None = None,
+        default_value: float | None = None,
+    ) -> float | None:
+        """Read a float configuration value."""
+        key = read_key(key)
+        if self.section and key in self.section:
+            return float(self.section[key])
+        return self._read_env(
+            env_key or key, default_value, lambda k, dv: self._env.float(k, dv)
+        )
+
+    def list(
+        self,
+        key: KeyValue,
+        env_key: EnvKeySet | None = None,
+        default_value: list | None = None,
+    ) -> list | None:
+        """Parse a list configuration value."""
+        key = read_key(key)
+        result = None
+        if self.section and key in self.section:
+            result = self.section[key]
+        if isinstance(result, list):
+            return result
+
+        if result is None:
+            result = self.str(key, env_key)
+        if result is not None:
+            # Fall back to a comma-separated string, e.g. "a, b, c" -> ["a", "b", "c"].
+            result = [s.strip() for s in result.split(",")]
+            return [s for s in result if s]
+        return default_value
diff --git a/func-app/graphrag/config/errors.py b/func-app/graphrag/config/errors.py
new file mode 100644
index 0000000000..9a2161b8af
--- /dev/null
+++ b/func-app/graphrag/config/errors.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+"""Errors for the default configuration."""
+
+
+class ApiKeyMissingError(ValueError):
+    """LLM Key missing error."""
+
+    def __init__(self, embedding: bool = False) -> None:
+        """Init method definition."""
+        api_type = "Embedding" if embedding else "Completion"
+        api_key = "GRAPHRAG_EMBEDDING_API_KEY" if embedding else "GRAPHRAG_LLM_API_KEY"
+        msg = f"API Key is required for {api_type} API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or {api_key} environment variable."
+        super().__init__(msg)
+
+
+class AzureApiBaseMissingError(ValueError):
+    """Azure API Base missing error."""
+
+    def __init__(self, embedding: bool = False) -> None:
+        """Init method definition."""
+        api_type = "Embedding" if embedding else "Completion"
+        api_base = "GRAPHRAG_EMBEDDING_API_BASE" if embedding else "GRAPHRAG_API_BASE"
+        msg = f"API Base is required for {api_type} API. Please set either the OPENAI_API_BASE, GRAPHRAG_API_BASE or {api_base} environment variable."
+        super().__init__(msg)
+
+
+class AzureDeploymentNameMissingError(ValueError):
+    """Azure Deployment Name missing error."""
+
+    def __init__(self, embedding: bool = False) -> None:
+        """Init method definition."""
+        api_type = "Embedding" if embedding else "Completion"
+        deployment_name = (
+            "GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME"
+            if embedding
+            else "GRAPHRAG_LLM_DEPLOYMENT_NAME"
+        )
+        msg = f"Deployment Name is required for {api_type} API. Please set either the OPENAI_DEPLOYMENT_NAME, GRAPHRAG_LLM_DEPLOYMENT_NAME or {deployment_name} environment variable."
+        super().__init__(msg)
diff --git a/func-app/graphrag/config/input_models/__init__.py b/func-app/graphrag/config/input_models/__init__.py
new file mode 100644
index 0000000000..f905ae38b2
--- /dev/null
+++ b/func-app/graphrag/config/input_models/__init__.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Interfaces for Default Config parameterization.""" + +from .cache_config_input import CacheConfigInput +from .chunking_config_input import ChunkingConfigInput +from .claim_extraction_config_input import ClaimExtractionConfigInput +from .cluster_graph_config_input import ClusterGraphConfigInput +from .community_reports_config_input import CommunityReportsConfigInput +from .embed_graph_config_input import EmbedGraphConfigInput +from .entity_extraction_config_input import EntityExtractionConfigInput +from .global_search_config_input import GlobalSearchConfigInput +from .graphrag_config_input import GraphRagConfigInput +from .input_config_input import InputConfigInput +from .llm_config_input import LLMConfigInput +from .llm_parameters_input import LLMParametersInput +from .local_search_config_input import LocalSearchConfigInput +from .parallelization_parameters_input import ParallelizationParametersInput +from .reporting_config_input import ReportingConfigInput +from .snapshots_config_input import SnapshotsConfigInput +from .storage_config_input import StorageConfigInput +from .summarize_descriptions_config_input import ( + SummarizeDescriptionsConfigInput, +) +from .text_embedding_config_input import TextEmbeddingConfigInput +from .umap_config_input import UmapConfigInput + +__all__ = [ + "CacheConfigInput", + "ChunkingConfigInput", + "ClaimExtractionConfigInput", + "ClusterGraphConfigInput", + "CommunityReportsConfigInput", + "EmbedGraphConfigInput", + "EntityExtractionConfigInput", + "GlobalSearchConfigInput", + "GraphRagConfigInput", + "InputConfigInput", + "LLMConfigInput", + "LLMParametersInput", + "LocalSearchConfigInput", + "ParallelizationParametersInput", + "ReportingConfigInput", + "SnapshotsConfigInput", + "StorageConfigInput", + "SummarizeDescriptionsConfigInput", + "TextEmbeddingConfigInput", + "UmapConfigInput", +] diff --git a/func-app/graphrag/config/input_models/cache_config_input.py b/func-app/graphrag/config/input_models/cache_config_input.py new file mode 100644 index 0000000000..fe88d35b44 --- /dev/null +++ b/func-app/graphrag/config/input_models/cache_config_input.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + +from graphrag.config.enums import CacheType + + +class CacheConfigInput(TypedDict): + """The default configuration section for Cache.""" + + type: NotRequired[CacheType | str | None] + base_dir: NotRequired[str | None] + connection_string: NotRequired[str | None] + container_name: NotRequired[str | None] + storage_account_blob_url: NotRequired[str | None] diff --git a/func-app/graphrag/config/input_models/chunking_config_input.py b/func-app/graphrag/config/input_models/chunking_config_input.py new file mode 100644 index 0000000000..bbf4fc735f --- /dev/null +++ b/func-app/graphrag/config/input_models/chunking_config_input.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + + +class ChunkingConfigInput(TypedDict): + """Configuration section for chunking.""" + + size: NotRequired[int | str | None] + overlap: NotRequired[int | str | None] + group_by_columns: NotRequired[list[str] | str | None] + strategy: NotRequired[dict | None] diff --git a/func-app/graphrag/config/input_models/claim_extraction_config_input.py b/func-app/graphrag/config/input_models/claim_extraction_config_input.py new file mode 100644 index 0000000000..f23e31d0a7 --- /dev/null +++ b/func-app/graphrag/config/input_models/claim_extraction_config_input.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired + +from .llm_config_input import LLMConfigInput + + +class ClaimExtractionConfigInput(LLMConfigInput): + """Configuration section for claim extraction.""" + + enabled: NotRequired[bool | None] + prompt: NotRequired[str | None] + description: NotRequired[str | None] + max_gleanings: NotRequired[int | str | None] + strategy: NotRequired[dict | None] + encoding_model: NotRequired[str | None] diff --git a/func-app/graphrag/config/input_models/cluster_graph_config_input.py b/func-app/graphrag/config/input_models/cluster_graph_config_input.py new file mode 100644 index 0000000000..eb6f9cd1c6 --- /dev/null +++ b/func-app/graphrag/config/input_models/cluster_graph_config_input.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + + +class ClusterGraphConfigInput(TypedDict): + """Configuration section for clustering graphs.""" + + max_cluster_size: NotRequired[int | None] + strategy: NotRequired[dict | None] diff --git a/func-app/graphrag/config/input_models/community_reports_config_input.py b/func-app/graphrag/config/input_models/community_reports_config_input.py new file mode 100644 index 0000000000..79ae3152e7 --- /dev/null +++ b/func-app/graphrag/config/input_models/community_reports_config_input.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired + +from .llm_config_input import LLMConfigInput + + +class CommunityReportsConfigInput(LLMConfigInput): + """Configuration section for community reports.""" + + prompt: NotRequired[str | None] + max_length: NotRequired[int | str | None] + max_input_length: NotRequired[int | str | None] + strategy: NotRequired[dict | None] diff --git a/func-app/graphrag/config/input_models/embed_graph_config_input.py b/func-app/graphrag/config/input_models/embed_graph_config_input.py new file mode 100644 index 0000000000..f8b6ee6faf --- /dev/null +++ b/func-app/graphrag/config/input_models/embed_graph_config_input.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class EmbedGraphConfigInput(TypedDict):
+    """The default configuration section for Node2Vec."""
+
+    enabled: NotRequired[bool | str | None]
+    num_walks: NotRequired[int | str | None]
+    walk_length: NotRequired[int | str | None]
+    window_size: NotRequired[int | str | None]
+    iterations: NotRequired[int | str | None]
+    random_seed: NotRequired[int | str | None]
+    strategy: NotRequired[dict | None]
diff --git a/func-app/graphrag/config/input_models/entity_extraction_config_input.py b/func-app/graphrag/config/input_models/entity_extraction_config_input.py
new file mode 100644
index 0000000000..f1d3587e99
--- /dev/null
+++ b/func-app/graphrag/config/input_models/entity_extraction_config_input.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired
+
+from .llm_config_input import LLMConfigInput
+
+
+class EntityExtractionConfigInput(LLMConfigInput):
+    """Configuration section for entity extraction."""
+
+    prompt: NotRequired[str | None]
+    entity_types: NotRequired[list[str] | str | None]
+    max_gleanings: NotRequired[int | str | None]
+    strategy: NotRequired[dict | None]
+    encoding_model: NotRequired[str | None]
diff --git a/func-app/graphrag/config/input_models/global_search_config_input.py b/func-app/graphrag/config/input_models/global_search_config_input.py
new file mode 100644
index 0000000000..e13fbbfa9e
--- /dev/null
+++ b/func-app/graphrag/config/input_models/global_search_config_input.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class GlobalSearchConfigInput(TypedDict):
+    """The default configuration section for Global Search."""
+
+    max_tokens: NotRequired[int | str | None]
+    data_max_tokens: NotRequired[int | str | None]
+    map_max_tokens: NotRequired[int | str | None]
+    reduce_max_tokens: NotRequired[int | str | None]
+    concurrency: NotRequired[int | str | None]
diff --git a/func-app/graphrag/config/input_models/graphrag_config_input.py b/func-app/graphrag/config/input_models/graphrag_config_input.py
new file mode 100644
index 0000000000..7c04dea2e3
--- /dev/null
+++ b/func-app/graphrag/config/input_models/graphrag_config_input.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired + +from .cache_config_input import CacheConfigInput +from .chunking_config_input import ChunkingConfigInput +from .claim_extraction_config_input import ClaimExtractionConfigInput +from .cluster_graph_config_input import ClusterGraphConfigInput +from .community_reports_config_input import CommunityReportsConfigInput +from .embed_graph_config_input import EmbedGraphConfigInput +from .entity_extraction_config_input import EntityExtractionConfigInput +from .global_search_config_input import GlobalSearchConfigInput +from .input_config_input import InputConfigInput +from .llm_config_input import LLMConfigInput +from .local_search_config_input import LocalSearchConfigInput +from .reporting_config_input import ReportingConfigInput +from .snapshots_config_input import SnapshotsConfigInput +from .storage_config_input import StorageConfigInput +from .summarize_descriptions_config_input import ( + SummarizeDescriptionsConfigInput, +) +from .text_embedding_config_input import TextEmbeddingConfigInput +from .umap_config_input import UmapConfigInput + + +class GraphRagConfigInput(LLMConfigInput): + """Base class for the Default-Configuration parameterization settings.""" + + reporting: NotRequired[ReportingConfigInput | None] + storage: NotRequired[StorageConfigInput | None] + cache: NotRequired[CacheConfigInput | None] + input: NotRequired[InputConfigInput | None] + embed_graph: NotRequired[EmbedGraphConfigInput | None] + embeddings: NotRequired[TextEmbeddingConfigInput | None] + chunks: NotRequired[ChunkingConfigInput | None] + snapshots: NotRequired[SnapshotsConfigInput | None] + entity_extraction: NotRequired[EntityExtractionConfigInput | None] + summarize_descriptions: NotRequired[SummarizeDescriptionsConfigInput | None] + community_reports: NotRequired[CommunityReportsConfigInput | None] + claim_extraction: NotRequired[ClaimExtractionConfigInput | None] + cluster_graph: NotRequired[ClusterGraphConfigInput | None] + umap: NotRequired[UmapConfigInput | None] + encoding_model: NotRequired[str | None] + skip_workflows: NotRequired[list[str] | str | None] + local_search: NotRequired[LocalSearchConfigInput | None] + global_search: NotRequired[GlobalSearchConfigInput | None] diff --git a/func-app/graphrag/config/input_models/input_config_input.py b/func-app/graphrag/config/input_models/input_config_input.py new file mode 100644 index 0000000000..4ff89d2c9a --- /dev/null +++ b/func-app/graphrag/config/input_models/input_config_input.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + +from graphrag.config.enums import InputFileType, InputType + + +class InputConfigInput(TypedDict): + """The default configuration section for Input.""" + + type: NotRequired[InputType | str | None] + file_type: NotRequired[InputFileType | str | None] + base_dir: NotRequired[str | None] + connection_string: NotRequired[str | None] + container_name: NotRequired[str | None] + file_encoding: NotRequired[str | None] + file_pattern: NotRequired[str | None] + source_column: NotRequired[str | None] + timestamp_column: NotRequired[str | None] + timestamp_format: NotRequired[str | None] + text_column: NotRequired[str | None] + title_column: NotRequired[str | None] + document_attribute_columns: NotRequired[list[str] | str | None] + storage_account_blob_url: NotRequired[str | None] diff --git a/func-app/graphrag/config/input_models/llm_config_input.py b/func-app/graphrag/config/input_models/llm_config_input.py new file mode 100644 index 0000000000..67231371b8 --- /dev/null +++ b/func-app/graphrag/config/input_models/llm_config_input.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from datashaper import AsyncType +from typing_extensions import NotRequired, TypedDict + +from .llm_parameters_input import LLMParametersInput +from .parallelization_parameters_input import ParallelizationParametersInput + + +class LLMConfigInput(TypedDict): + """Base class for LLM-configured steps.""" + + llm: NotRequired[LLMParametersInput | None] + parallelization: NotRequired[ParallelizationParametersInput | None] + async_mode: NotRequired[AsyncType | str | None] diff --git a/func-app/graphrag/config/input_models/llm_parameters_input.py b/func-app/graphrag/config/input_models/llm_parameters_input.py new file mode 100644 index 0000000000..c89c6c0922 --- /dev/null +++ b/func-app/graphrag/config/input_models/llm_parameters_input.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""LLM Parameters model."""
+
+from typing_extensions import NotRequired, TypedDict
+
+from graphrag.config.enums import LLMType
+
+
+class LLMParametersInput(TypedDict):
+    """LLM Parameters model."""
+
+    api_key: NotRequired[str | None]
+    type: NotRequired[LLMType | str | None]
+    model: NotRequired[str | None]
+    max_tokens: NotRequired[int | str | None]
+    request_timeout: NotRequired[float | str | None]
+    api_base: NotRequired[str | None]
+    api_version: NotRequired[str | None]
+    organization: NotRequired[str | None]
+    proxy: NotRequired[str | None]
+    cognitive_services_endpoint: NotRequired[str | None]
+    deployment_name: NotRequired[str | None]
+    model_supports_json: NotRequired[bool | str | None]
+    tokens_per_minute: NotRequired[int | str | None]
+    requests_per_minute: NotRequired[int | str | None]
+    max_retries: NotRequired[int | str | None]
+    max_retry_wait: NotRequired[float | str | None]
+    sleep_on_rate_limit_recommendation: NotRequired[bool | str | None]
+    concurrent_requests: NotRequired[int | str | None]
diff --git a/func-app/graphrag/config/input_models/local_search_config_input.py b/func-app/graphrag/config/input_models/local_search_config_input.py
new file mode 100644
index 0000000000..23df40102a
--- /dev/null
+++ b/func-app/graphrag/config/input_models/local_search_config_input.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class LocalSearchConfigInput(TypedDict):
+    """The default configuration section for Local Search."""
+
+    text_unit_prop: NotRequired[float | str | None]
+    community_prop: NotRequired[float | str | None]
+    conversation_history_max_turns: NotRequired[int | str | None]
+    top_k_entities: NotRequired[int | str | None]
+    top_k_relationships: NotRequired[int | str | None]
+    max_tokens: NotRequired[int | str | None]
+    llm_max_tokens: NotRequired[int | str | None]
diff --git a/func-app/graphrag/config/input_models/parallelization_parameters_input.py b/func-app/graphrag/config/input_models/parallelization_parameters_input.py
new file mode 100644
index 0000000000..e9204437b2
--- /dev/null
+++ b/func-app/graphrag/config/input_models/parallelization_parameters_input.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parallelization parameters model."""
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class ParallelizationParametersInput(TypedDict):
+    """Parallelization parameters model."""
+
+    stagger: NotRequired[float | str | None]
+    num_threads: NotRequired[int | str | None]
diff --git a/func-app/graphrag/config/input_models/query_context_config_input.py b/func-app/graphrag/config/input_models/query_context_config_input.py
new file mode 100644
index 0000000000..c8f3d2e783
--- /dev/null
+++ b/func-app/graphrag/config/input_models/query_context_config_input.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from typing_extensions import NotRequired, TypedDict
+
+
+class QueryContextConfigInput(TypedDict):
+    """The default configuration section for Query Context."""
+
+    files: NotRequired[list[str] | str | None]
+    """The list of the files on which the query should be run."""
diff --git a/func-app/graphrag/config/input_models/reporting_config_input.py b/func-app/graphrag/config/input_models/reporting_config_input.py
new file mode 100644
index 0000000000..a224f0b440
--- /dev/null
+++ b/func-app/graphrag/config/input_models/reporting_config_input.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + +from graphrag.config.enums import ReportingType + + +class ReportingConfigInput(TypedDict): + """The default configuration section for Reporting.""" + + type: NotRequired[ReportingType | str | None] + base_dir: NotRequired[str | None] + connection_string: NotRequired[str | None] + container_name: NotRequired[str | None] + storage_account_blob_url: NotRequired[str | None] diff --git a/func-app/graphrag/config/input_models/snapshots_config_input.py b/func-app/graphrag/config/input_models/snapshots_config_input.py new file mode 100644 index 0000000000..c20becb071 --- /dev/null +++ b/func-app/graphrag/config/input_models/snapshots_config_input.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + + +class SnapshotsConfigInput(TypedDict): + """Configuration section for snapshots.""" + + graphml: NotRequired[bool | str | None] + raw_entities: NotRequired[bool | str | None] + top_level_nodes: NotRequired[bool | str | None] diff --git a/func-app/graphrag/config/input_models/storage_config_input.py b/func-app/graphrag/config/input_models/storage_config_input.py new file mode 100644 index 0000000000..cc5caf7952 --- /dev/null +++ b/func-app/graphrag/config/input_models/storage_config_input.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + +from graphrag.config.enums import StorageType + + +class StorageConfigInput(TypedDict): + """The default configuration section for Storage.""" + + type: NotRequired[StorageType | str | None] + base_dir: NotRequired[str | None] + connection_string: NotRequired[str | None] + container_name: NotRequired[str | None] + storage_account_blob_url: NotRequired[str | None] diff --git a/func-app/graphrag/config/input_models/summarize_descriptions_config_input.py b/func-app/graphrag/config/input_models/summarize_descriptions_config_input.py new file mode 100644 index 0000000000..6ce756e558 --- /dev/null +++ b/func-app/graphrag/config/input_models/summarize_descriptions_config_input.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired + +from .llm_config_input import LLMConfigInput + + +class SummarizeDescriptionsConfigInput(LLMConfigInput): + """Configuration section for description summarization.""" + + prompt: NotRequired[str | None] + max_length: NotRequired[int | str | None] + strategy: NotRequired[dict | None] diff --git a/func-app/graphrag/config/input_models/text_embedding_config_input.py b/func-app/graphrag/config/input_models/text_embedding_config_input.py new file mode 100644 index 0000000000..a7e176c658 --- /dev/null +++ b/func-app/graphrag/config/input_models/text_embedding_config_input.py @@ -0,0 +1,23 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired + +from graphrag.config.enums import ( + TextEmbeddingTarget, +) + +from .llm_config_input import LLMConfigInput + + +class TextEmbeddingConfigInput(LLMConfigInput): + """Configuration section for text embeddings.""" + + batch_size: NotRequired[int | str | None] + batch_max_tokens: NotRequired[int | str | None] + target: NotRequired[TextEmbeddingTarget | str | None] + skip: NotRequired[list[str] | str | None] + vector_store: NotRequired[dict | None] + strategy: NotRequired[dict | None] diff --git a/func-app/graphrag/config/input_models/umap_config_input.py b/func-app/graphrag/config/input_models/umap_config_input.py new file mode 100644 index 0000000000..543ca385e0 --- /dev/null +++ b/func-app/graphrag/config/input_models/umap_config_input.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from typing_extensions import NotRequired, TypedDict + + +class UmapConfigInput(TypedDict): + """Configuration section for UMAP.""" + + enabled: NotRequired[bool | str | None] diff --git a/func-app/graphrag/config/models/__init__.py b/func-app/graphrag/config/models/__init__.py new file mode 100644 index 0000000000..f1d206ef85 --- /dev/null +++ b/func-app/graphrag/config/models/__init__.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Interfaces for Default Config parameterization.""" + +from .cache_config import CacheConfig +from .chunking_config import ChunkingConfig +from .claim_extraction_config import ClaimExtractionConfig +from .cluster_graph_config import ClusterGraphConfig +from .community_reports_config import CommunityReportsConfig +from .embed_graph_config import EmbedGraphConfig +from .entity_extraction_config import EntityExtractionConfig +from .global_search_config import GlobalSearchConfig +from .graph_rag_config import GraphRagConfig +from .input_config import InputConfig +from .llm_config import LLMConfig +from .llm_parameters import LLMParameters +from .local_search_config import LocalSearchConfig +from .parallelization_parameters import ParallelizationParameters +from .query_context_config import QueryContextConfig +from .reporting_config import ReportingConfig +from .snapshots_config import SnapshotsConfig +from .storage_config import StorageConfig +from .summarize_descriptions_config import SummarizeDescriptionsConfig +from .text_embedding_config import TextEmbeddingConfig +from .umap_config import UmapConfig +from .graphdb_config import GraphDBConfig + +__all__ = [ + "CacheConfig", + "ChunkingConfig", + "ClaimExtractionConfig", + "ClusterGraphConfig", + "CommunityReportsConfig", + "EmbedGraphConfig", + "EntityExtractionConfig", + "GlobalSearchConfig", + "GraphRagConfig", + "InputConfig", + "LLMConfig", + "LLMParameters", + "LocalSearchConfig", + "ParallelizationParameters", + "QueryContextConfig", + "ReportingConfig", + "SnapshotsConfig", + "StorageConfig", + "SummarizeDescriptionsConfig", + "TextEmbeddingConfig", + "UmapConfig", + "GraphDBConfig", +] diff --git a/func-app/graphrag/config/models/cache_config.py b/func-app/graphrag/config/models/cache_config.py new file mode 100644 index 0000000000..4589edce0b --- /dev/null +++ b/func-app/graphrag/config/models/cache_config.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pydantic import BaseModel, Field + +import graphrag.config.defaults as defs +from graphrag.config.enums import CacheType + + +class CacheConfig(BaseModel): + """The default configuration section for Cache.""" + + type: CacheType = Field( + description="The cache type to use.", default=defs.CACHE_TYPE + ) + base_dir: str = Field( + description="The base directory for the cache.", default=defs.CACHE_BASE_DIR + ) + connection_string: str | None = Field( + description="The cache connection string to use.", default=None + ) + container_name: str | None = Field( + description="The cache container name to use.", default=None + ) + storage_account_blob_url: str | None = Field( + description="The storage account blob url to use.", default=None + ) diff --git a/func-app/graphrag/config/models/chunking_config.py b/func-app/graphrag/config/models/chunking_config.py new file mode 100644 index 0000000000..4ca8a8d38c --- /dev/null +++ b/func-app/graphrag/config/models/chunking_config.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pydantic import BaseModel, Field + +import graphrag.config.defaults as defs + + +class ChunkingConfig(BaseModel): + """Configuration section for chunking.""" + + size: int = Field(description="The chunk size to use.", default=defs.CHUNK_SIZE) + overlap: int = Field( + description="The chunk overlap to use.", default=defs.CHUNK_OVERLAP + ) + group_by_columns: list[str] = Field( + description="The chunk by columns to use.", + default=defs.CHUNK_GROUP_BY_COLUMNS, + ) + strategy: dict | None = Field( + description="The chunk strategy to use, overriding the default tokenization strategy", + default=None, + ) + encoding_model: str | None = Field( + default=None, description="The encoding model to use." + ) + + def resolved_strategy(self, encoding_model: str) -> dict: + """Get the resolved chunking strategy.""" + from graphrag.index.verbs.text.chunk import ChunkStrategyType + + return self.strategy or { + "type": ChunkStrategyType.tokens, + "chunk_size": self.size, + "chunk_overlap": self.overlap, + "group_by_columns": self.group_by_columns, + "encoding_name": self.encoding_model or encoding_model, + } diff --git a/func-app/graphrag/config/models/claim_extraction_config.py b/func-app/graphrag/config/models/claim_extraction_config.py new file mode 100644 index 0000000000..a26fdad26e --- /dev/null +++ b/func-app/graphrag/config/models/claim_extraction_config.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pathlib import Path + +from pydantic import Field + +import graphrag.config.defaults as defs + +from .llm_config import LLMConfig + + +class ClaimExtractionConfig(LLMConfig): + """Configuration section for claim extraction.""" + + enabled: bool = Field( + description="Whether claim extraction is enabled.", + ) + prompt: str | None = Field( + description="The claim extraction prompt to use.", default=None + ) + description: str = Field( + description="The claim description to use.", + default=defs.CLAIM_DESCRIPTION, + ) + max_gleanings: int = Field( + description="The maximum number of entity gleanings to use.", + default=defs.CLAIM_MAX_GLEANINGS, + ) + strategy: dict | None = Field( + description="The override strategy to use.", default=None + ) + encoding_model: str | None = Field( + default=None, description="The encoding model to use." + ) + + def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict: + """Get the resolved claim extraction strategy.""" + from graphrag.index.verbs.covariates.extract_covariates import ( + ExtractClaimsStrategyType, + ) + + return self.strategy or { + "type": ExtractClaimsStrategyType.graph_intelligence, + "llm": self.llm.model_dump(), + **self.parallelization.model_dump(), + "extraction_prompt": (Path(root_dir) / self.prompt) + .read_bytes() + .decode(encoding="utf-8") + if self.prompt + else None, + "claim_description": self.description, + "max_gleanings": self.max_gleanings, + "encoding_name": self.encoding_model or encoding_model, + } diff --git a/func-app/graphrag/config/models/cluster_graph_config.py b/func-app/graphrag/config/models/cluster_graph_config.py new file mode 100644 index 0000000000..3029baebcb --- /dev/null +++ b/func-app/graphrag/config/models/cluster_graph_config.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pydantic import BaseModel, Field + +import graphrag.config.defaults as defs + + +class ClusterGraphConfig(BaseModel): + """Configuration section for clustering graphs.""" + + max_cluster_size: int = Field( + description="The maximum cluster size to use.", default=defs.MAX_CLUSTER_SIZE + ) + strategy: dict | None = Field( + description="The cluster strategy to use.", default=None + ) + + def resolved_strategy(self) -> dict: + """Get the resolved cluster strategy.""" + from graphrag.index.verbs.graph.clustering import GraphCommunityStrategyType + + return self.strategy or { + "type": GraphCommunityStrategyType.leiden, + "max_cluster_size": self.max_cluster_size, + } diff --git a/func-app/graphrag/config/models/community_reports_config.py b/func-app/graphrag/config/models/community_reports_config.py new file mode 100644 index 0000000000..ab55063cec --- /dev/null +++ b/func-app/graphrag/config/models/community_reports_config.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pathlib import Path
+
+from pydantic import Field
+
+import graphrag.config.defaults as defs
+
+from .llm_config import LLMConfig
+
+
+class CommunityReportsConfig(LLMConfig):
+    """Configuration section for community reports."""
+
+    prompt: str | None = Field(
+        description="The community report extraction prompt to use.", default=None
+    )
+    max_length: int = Field(
+        description="The community report maximum length in tokens.",
+        default=defs.COMMUNITY_REPORT_MAX_LENGTH,
+    )
+    max_input_length: int = Field(
+        description="The maximum input length in tokens to use when generating reports.",
+        default=defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH,
+    )
+    strategy: dict | None = Field(
+        description="The override strategy to use.", default=None
+    )
+
+    def resolved_strategy(self, root_dir) -> dict:
+        """Get the resolved community report extraction strategy."""
+        from graphrag.index.verbs.graph.report import CreateCommunityReportsStrategyType
+
+        return self.strategy or {
+            "type": CreateCommunityReportsStrategyType.graph_intelligence,
+            "llm": self.llm.model_dump(),
+            **self.parallelization.model_dump(),
+            "extraction_prompt": (Path(root_dir) / self.prompt)
+            .read_bytes()
+            .decode(encoding="utf-8")
+            if self.prompt
+            else None,
+            "max_report_length": self.max_length,
+            "max_input_length": self.max_input_length,
+        }
diff --git a/func-app/graphrag/config/models/embed_graph_config.py b/func-app/graphrag/config/models/embed_graph_config.py
new file mode 100644
index 0000000000..8b7677ab10
--- /dev/null
+++ b/func-app/graphrag/config/models/embed_graph_config.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class EmbedGraphConfig(BaseModel):
+    """The default configuration section for Node2Vec."""
+
+    enabled: bool = Field(
+        description="A flag indicating whether to enable node2vec.",
+        default=defs.NODE2VEC_ENABLED,
+    )
+    num_walks: int = Field(
+        description="The node2vec number of walks.", default=defs.NODE2VEC_NUM_WALKS
+    )
+    walk_length: int = Field(
+        description="The node2vec walk length.", default=defs.NODE2VEC_WALK_LENGTH
+    )
+    window_size: int = Field(
+        description="The node2vec window size.", default=defs.NODE2VEC_WINDOW_SIZE
+    )
+    iterations: int = Field(
+        description="The node2vec iterations.", default=defs.NODE2VEC_ITERATIONS
+    )
+    random_seed: int = Field(
+        description="The node2vec random seed.", default=defs.NODE2VEC_RANDOM_SEED
+    )
+    strategy: dict | None = Field(
+        description="The graph embedding strategy override.", default=None
+    )
+
+    def resolved_strategy(self) -> dict:
+        """Get the resolved node2vec strategy."""
+        from graphrag.index.verbs.graph.embed import EmbedGraphStrategyType
+
+        return self.strategy or {
+            "type": EmbedGraphStrategyType.node2vec,
+            "num_walks": self.num_walks,
+            "walk_length": self.walk_length,
+            "window_size": self.window_size,
+            "iterations": self.iterations,
+            "random_seed": self.random_seed,
+        }
diff --git a/func-app/graphrag/config/models/entity_extraction_config.py b/func-app/graphrag/config/models/entity_extraction_config.py
new file mode 100644
index 0000000000..ca160bc4e2
--- /dev/null
+++ b/func-app/graphrag/config/models/entity_extraction_config.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pathlib import Path + +from pydantic import Field + +import graphrag.config.defaults as defs + +from .llm_config import LLMConfig + + +class EntityExtractionConfig(LLMConfig): + """Configuration section for entity extraction.""" + + prompt: str | None = Field( + description="The entity extraction prompt to use.", default=None + ) + entity_types: list[str] = Field( + description="The entity extraction entity types to use.", + default=defs.ENTITY_EXTRACTION_ENTITY_TYPES, + ) + max_gleanings: int = Field( + description="The maximum number of entity gleanings to use.", + default=defs.ENTITY_EXTRACTION_MAX_GLEANINGS, + ) + strategy: dict | None = Field( + description="Override the default entity extraction strategy", default=None + ) + encoding_model: str | None = Field( + default=None, description="The encoding model to use." + ) + + def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict: + """Get the resolved entity extraction strategy.""" + from graphrag.index.verbs.entities.extraction import ExtractEntityStrategyType + + return self.strategy or { + "type": ExtractEntityStrategyType.graph_intelligence, + "llm": self.llm.model_dump(), + **self.parallelization.model_dump(), + "extraction_prompt": (Path(root_dir) / self.prompt) + .read_bytes() + .decode(encoding="utf-8") + if self.prompt + else None, + "max_gleanings": self.max_gleanings, + # It's prechunked in create_base_text_units + "encoding_name": self.encoding_model or encoding_model, + "prechunked": True, + } diff --git a/func-app/graphrag/config/models/global_search_config.py b/func-app/graphrag/config/models/global_search_config.py new file mode 100644 index 0000000000..9eb388c373 --- /dev/null +++ b/func-app/graphrag/config/models/global_search_config.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class GlobalSearchConfig(BaseModel):
+    """The default configuration section for Global Search."""
+
+    temperature: float | None = Field(
+        description="The temperature to use for token generation.",
+        default=defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
+    )
+    top_p: float | None = Field(
+        description="The top-p value to use for token generation.",
+        default=defs.GLOBAL_SEARCH_LLM_TOP_P,
+    )
+    n: int | None = Field(
+        description="The number of completions to generate.",
+        default=defs.GLOBAL_SEARCH_LLM_N,
+    )
+    max_tokens: int = Field(
+        description="The maximum context size in tokens.",
+        default=defs.GLOBAL_SEARCH_MAX_TOKENS,
+    )
+    data_max_tokens: int = Field(
+        description="The data llm maximum tokens.",
+        default=defs.GLOBAL_SEARCH_DATA_MAX_TOKENS,
+    )
+    map_max_tokens: int = Field(
+        description="The map llm maximum tokens.",
+        default=defs.GLOBAL_SEARCH_MAP_MAX_TOKENS,
+    )
+    reduce_max_tokens: int = Field(
+        description="The reduce llm maximum tokens.",
+        default=defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS,
+    )
+    concurrency: int = Field(
+        description="The number of concurrent requests.",
+        default=defs.GLOBAL_SEARCH_CONCURRENCY,
+    )
diff --git a/func-app/graphrag/config/models/graph_rag_config.py b/func-app/graphrag/config/models/graph_rag_config.py
new file mode 100644
index 0000000000..e7249a9016
--- /dev/null
+++ b/func-app/graphrag/config/models/graph_rag_config.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from devtools import pformat
+from pydantic import Field
+
+import graphrag.config.defaults as defs
+
+from .cache_config import CacheConfig
+from .chunking_config import ChunkingConfig
+from .claim_extraction_config import ClaimExtractionConfig
+from .cluster_graph_config import ClusterGraphConfig
+from .community_reports_config import CommunityReportsConfig
+from .embed_graph_config import EmbedGraphConfig
+from .entity_extraction_config import EntityExtractionConfig
+from .global_search_config import GlobalSearchConfig
+from .graphdb_config import GraphDBConfig
+from .input_config import InputConfig
+from .llm_config import LLMConfig
+from .local_search_config import LocalSearchConfig
+from .query_context_config import QueryContextConfig
+from .reporting_config import ReportingConfig
+from .snapshots_config import SnapshotsConfig
+from .storage_config import StorageConfig
+from .summarize_descriptions_config import (
+    SummarizeDescriptionsConfig,
+)
+from .text_embedding_config import TextEmbeddingConfig
+from .umap_config import UmapConfig
+
+
+class GraphRagConfig(LLMConfig):
+    """Base class for the Default-Configuration parameterization settings."""
+
+    def __repr__(self) -> str:
+        """Get a string representation."""
+        return pformat(self, highlight=False)
+
+    def __str__(self):
+        """Get a string representation."""
+        return self.model_dump_json(indent=4)
+
+    root_dir: str = Field(
+        description="The root directory for the configuration.", default=None
+    )
+
+    reporting: ReportingConfig = Field(
+        description="The reporting configuration.", default=ReportingConfig()
+    )
+    """The reporting configuration."""
+
+    storage: StorageConfig = Field(
+        description="The storage configuration.", default=StorageConfig()
+    )
+    """The storage configuration."""
+
+    cache: CacheConfig = Field(
+        description="The cache configuration.", default=CacheConfig()
+    )
+    """The cache configuration."""
+
+    input: InputConfig = Field(
+        description="The input configuration.", default=InputConfig()
+    )
+    """The input configuration."""
+
+    embed_graph: EmbedGraphConfig = Field(
+        description="Graph embedding configuration.",
+        default=EmbedGraphConfig(),
+    )
+    """Graph Embedding configuration."""
+
+    embeddings: TextEmbeddingConfig = Field(
+        description="The embeddings LLM configuration to use.",
+        default=TextEmbeddingConfig(),
+    )
+    """The embeddings LLM configuration to use."""
+
+    chunks: ChunkingConfig = Field(
+        description="The chunking configuration to use.",
+        default=ChunkingConfig(),
+    )
+    """The chunking configuration to use."""
+
+    snapshots: SnapshotsConfig = Field(
+        description="The snapshots configuration to use.",
+        default=SnapshotsConfig(),
+    )
+    """The snapshots configuration to use."""
+
+    entity_extraction: EntityExtractionConfig = Field(
+        description="The entity extraction configuration to use.",
+        default=EntityExtractionConfig(),
+    )
+    """The entity extraction configuration to use."""
+
+    summarize_descriptions: SummarizeDescriptionsConfig = Field(
+        description="The description summarization configuration to use.",
+        default=SummarizeDescriptionsConfig(),
+    )
+    """The description summarization configuration to use."""
+
+    community_reports: CommunityReportsConfig = Field(
+        description="The community reports configuration to use.",
+        default=CommunityReportsConfig(),
+    )
+    """The community reports configuration to use."""
+
+    claim_extraction: ClaimExtractionConfig = Field(
+        description="The claim extraction configuration to use.",
+        default=ClaimExtractionConfig(
+            enabled=defs.CLAIM_EXTRACTION_ENABLED,
+        ),
+    )
+    """The claim extraction configuration to use."""
+
+    cluster_graph: ClusterGraphConfig = Field(
+        description="The cluster graph configuration to use.",
+        default=ClusterGraphConfig(),
+    )
+    """The cluster graph configuration to use."""
+
+    umap: UmapConfig = Field(
+        description="The UMAP configuration to use.", default=UmapConfig()
+    )
+    """The UMAP configuration to use."""
+
+    local_search: LocalSearchConfig = Field(
+        description="The local search configuration.", default=LocalSearchConfig()
+    )
+    """The local search configuration."""
+
+    global_search: GlobalSearchConfig = Field(
+        description="The global search configuration.", default=GlobalSearchConfig()
+    )
+    """The global search configuration."""
+
+    encoding_model: str = Field(
+        description="The encoding model to use.", default=defs.ENCODING_MODEL
+    )
+    """The encoding model to use."""
+
+    skip_workflows: list[str] = Field(
+        description="The workflows to skip, usually for testing reasons.", default=[]
+    )
+    """The workflows to skip, usually for testing reasons."""
+
+    query_context: QueryContextConfig = Field(
+        description="The query context to use.", default=QueryContextConfig()
+    )
+    """The query context to use."""
+
+    graphdb: GraphDBConfig = Field(
+        description="The parameters to use for graphdb.", default=GraphDBConfig()
+    )
+    """The parameters to use for graphdb."""
\ No newline at end of file
diff --git a/func-app/graphrag/config/models/graphdb_config.py b/func-app/graphrag/config/models/graphdb_config.py
new file mode 100644
index 0000000000..8ee0f9d276
--- /dev/null
+++ b/func-app/graphrag/config/models/graphdb_config.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+
+class GraphDBConfig(BaseModel):
+    """The default configuration section for the graph database."""
+
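+    # Connection and credential settings for the Gremlin/Cosmos DB graph,
+    # read from the graphdb section of settings.yaml; used when enabled is true.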
+    account_name: str | None = Field(
+        description="Graphdb account name",
+        default=None,
+    )
+    account_key: str | None = Field(
+        description="Graphdb account key",
+        default=None,
+    )
+    username: str | None = Field(
+        description="Graphdb username",
+        default=None,
+    )
+    enabled: bool = Field(
+        description="Flag to enable querying into graphdb",
+        default=False,
+    )
+
+    cosmos_url: str | None = Field(
+        description="Cosmos account url",
+        default=None,
+    )
+
+    gremlin_url: str | None = Field(
+        description="Gremlin db url",
+        default=None,
+    )
\ No newline at end of file
diff --git a/func-app/graphrag/config/models/input_config.py b/func-app/graphrag/config/models/input_config.py
new file mode 100644
index 0000000000..f9e5847af6
--- /dev/null
+++ b/func-app/graphrag/config/models/input_config.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+from graphrag.config.enums import InputFileType, InputType
+
+
+class InputConfig(BaseModel):
+    """The default configuration section for Input."""
+
+    type: InputType = Field(
+        description="The input type to use.", default=defs.INPUT_TYPE
+    )
+    file_type: InputFileType = Field(
+        description="The input file type to use.", default=defs.INPUT_FILE_TYPE
+    )
+    base_dir: str = Field(
+        description="The input base directory to use.", default=defs.INPUT_BASE_DIR
+    )
+    connection_string: str | None = Field(
+        description="The azure blob storage connection string to use.", default=None
+    )
+    storage_account_blob_url: str | None = Field(
+        description="The storage account blob url to use.", default=None
+    )
+    container_name: str | None = Field(
+        description="The azure blob storage container name to use.", default=None
+    )
+    encoding: str | None = Field(
+        description="The input file encoding to use.",
+        default=defs.INPUT_FILE_ENCODING,
+    )
+    file_pattern: str = Field(
+        description="The input file pattern to use.", default=defs.INPUT_TEXT_PATTERN
+    )
+    file_filter: dict[str, str] | None = Field(
+        description="The optional file filter for the input files.", default=None
+    )
+    source_column: str | None = Field(
+        description="The input source column to use.", default=None
+    )
+    timestamp_column: str | None = Field(
+        description="The input timestamp column to use.", default=None
+    )
+    timestamp_format: str | None = Field(
+        description="The input timestamp format to use.", default=None
+    )
+    text_column: str = Field(
+        description="The input text column to use.", default=defs.INPUT_TEXT_COLUMN
+    )
+    title_column: str | None = Field(
+        description="The input title column to use.", default=None
+    )
+    document_attribute_columns: list[str] = Field(
+        description="The document attribute columns to use.", default=[]
+    )
diff --git a/func-app/graphrag/config/models/llm_config.py b/func-app/graphrag/config/models/llm_config.py
new file mode 100644
index 0000000000..62c193b0c5
--- /dev/null
+++ b/func-app/graphrag/config/models/llm_config.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from datashaper import AsyncType +from pydantic import BaseModel, Field + +import graphrag.config.defaults as defs + +from .llm_parameters import LLMParameters +from .parallelization_parameters import ParallelizationParameters + + +class LLMConfig(BaseModel): + """Base class for LLM-configured steps.""" + + llm: LLMParameters = Field( + description="The LLM configuration to use.", default=LLMParameters() + ) + parallelization: ParallelizationParameters = Field( + description="The parallelization configuration to use.", + default=ParallelizationParameters(), + ) + async_mode: AsyncType = Field( + description="The async mode to use.", default=defs.ASYNC_MODE + ) diff --git a/func-app/graphrag/config/models/llm_parameters.py b/func-app/graphrag/config/models/llm_parameters.py new file mode 100644 index 0000000000..df81138a2f --- /dev/null +++ b/func-app/graphrag/config/models/llm_parameters.py @@ -0,0 +1,87 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Parameters model.""" + +from pydantic import BaseModel, ConfigDict, Field + +import graphrag.config.defaults as defs +from graphrag.config.enums import LLMType + + +class LLMParameters(BaseModel): + """LLM Parameters model.""" + + model_config = ConfigDict(protected_namespaces=(), extra="allow") + api_key: str | None = Field( + description="The API key to use for the LLM service.", + default=None, + ) + type: LLMType = Field( + description="The type of LLM model to use.", default=defs.LLM_TYPE + ) + model: str = Field(description="The LLM model to use.", default=defs.LLM_MODEL) + max_tokens: int | None = Field( + description="The maximum number of tokens to generate.", + default=defs.LLM_MAX_TOKENS, + ) + temperature: float | None = Field( + description="The temperature to use for token generation.", + default=defs.LLM_TEMPERATURE, + ) + top_p: float | None = Field( + description="The top-p value to use for token generation.", + default=defs.LLM_TOP_P, + ) + n: int | None = Field( + description="The number of completions to generate.", + default=defs.LLM_N, + ) + request_timeout: float = Field( + description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT + ) + api_base: str | None = Field( + description="The base URL for the LLM API.", default=None + ) + api_version: str | None = Field( + description="The version of the LLM API to use.", default=None + ) + organization: str | None = Field( + description="The organization to use for the LLM service.", default=None + ) + proxy: str | None = Field( + description="The proxy to use for the LLM service.", default=None + ) + cognitive_services_endpoint: str | None = Field( + description="The endpoint to reach cognitives services.", default=None + ) + deployment_name: str | None = Field( + description="The deployment name to use for the LLM service.", default=None + ) + model_supports_json: bool | None = Field( + description="Whether the model supports JSON output mode.", default=None + ) + tokens_per_minute: int = Field( + description="The number of tokens per minute to use for the LLM service.", + default=defs.LLM_TOKENS_PER_MINUTE, + ) + requests_per_minute: int = Field( + description="The number of requests per minute to use for the LLM service.", + default=defs.LLM_REQUESTS_PER_MINUTE, + ) + max_retries: int = Field( + description="The maximum number of retries to use for the LLM service.", + default=defs.LLM_MAX_RETRIES, + ) + 
+    max_retry_wait: float = Field(
+        description="The maximum retry wait to use for the LLM service.",
+        default=defs.LLM_MAX_RETRY_WAIT,
+    )
+    sleep_on_rate_limit_recommendation: bool = Field(
+        description="Whether to sleep on rate limit recommendations.",
+        default=defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION,
+    )
+    concurrent_requests: int = Field(
+        description="The number of concurrent requests to use for the LLM service.",
+        default=defs.LLM_CONCURRENT_REQUESTS,
+    )
diff --git a/func-app/graphrag/config/models/local_search_config.py b/func-app/graphrag/config/models/local_search_config.py
new file mode 100644
index 0000000000..c41344daef
--- /dev/null
+++ b/func-app/graphrag/config/models/local_search_config.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class LocalSearchConfig(BaseModel):
+    """The default configuration section for Local Search."""
+
+    text_unit_prop: float = Field(
+        description="The text unit proportion.",
+        default=defs.LOCAL_SEARCH_TEXT_UNIT_PROP,
+    )
+    community_prop: float = Field(
+        description="The community proportion.",
+        default=defs.LOCAL_SEARCH_COMMUNITY_PROP,
+    )
+    conversation_history_max_turns: int = Field(
+        description="The conversation history maximum turns.",
+        default=defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS,
+    )
+    top_k_entities: int = Field(
+        description="The top k mapped entities.",
+        default=defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
+    )
+    top_k_relationships: int = Field(
+        description="The top k mapped relations.",
+        default=defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
+    )
+    temperature: float | None = Field(
+        description="The temperature to use for token generation.",
+        default=defs.LOCAL_SEARCH_LLM_TEMPERATURE,
+    )
+    top_p: float | None = Field(
+        description="The top-p value to use for token generation.",
+        default=defs.LOCAL_SEARCH_LLM_TOP_P,
+    )
+    n: int | None = Field(
+        description="The number of completions to generate.",
+        default=defs.LOCAL_SEARCH_LLM_N,
+    )
+    max_tokens: int = Field(
+        description="The maximum tokens.", default=defs.LOCAL_SEARCH_MAX_TOKENS
+    )
+    llm_max_tokens: int = Field(
+        description="The LLM maximum tokens.", default=defs.LOCAL_SEARCH_LLM_MAX_TOKENS
+    )
diff --git a/func-app/graphrag/config/models/parallelization_parameters.py b/func-app/graphrag/config/models/parallelization_parameters.py
new file mode 100644
index 0000000000..80a85b8639
--- /dev/null
+++ b/func-app/graphrag/config/models/parallelization_parameters.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parallelization parameters model."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class ParallelizationParameters(BaseModel):
+    """Parallelization parameters model."""
+
+    stagger: float = Field(
+        description="The stagger to use for the LLM service.",
+        default=defs.PARALLELIZATION_STAGGER,
+    )
+    num_threads: int = Field(
+        description="The number of threads to use for the LLM service.",
+        default=defs.PARALLELIZATION_NUM_THREADS,
+    )
diff --git a/func-app/graphrag/config/models/query_context_config.py b/func-app/graphrag/config/models/query_context_config.py
new file mode 100644
index 0000000000..15626efba9
--- /dev/null
+++ b/func-app/graphrag/config/models/query_context_config.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+
+class QueryContextConfig(BaseModel):
+    """The default configuration section for Query Context."""
+
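+    # Populated from the query_context section of settings.yaml.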
+    files: list[str] = Field(
+        description="The list of the files on which the query should be run.",
+        default=[],
+    )
\ No newline at end of file
diff --git a/func-app/graphrag/config/models/reporting_config.py b/func-app/graphrag/config/models/reporting_config.py
new file mode 100644
index 0000000000..35e86cf5da
--- /dev/null
+++ b/func-app/graphrag/config/models/reporting_config.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+from graphrag.config.enums import ReportingType
+
+
+class ReportingConfig(BaseModel):
+    """The default configuration section for Reporting."""
+
+    type: ReportingType = Field(
+        description="The reporting type to use.", default=defs.REPORTING_TYPE
+    )
+    base_dir: str = Field(
+        description="The base directory for reporting.",
+        default=defs.REPORTING_BASE_DIR,
+    )
+    connection_string: str | None = Field(
+        description="The reporting connection string to use.", default=None
+    )
+    container_name: str | None = Field(
+        description="The reporting container name to use.", default=None
+    )
+    storage_account_blob_url: str | None = Field(
+        description="The storage account blob url to use.", default=None
+    )
diff --git a/func-app/graphrag/config/models/snapshots_config.py b/func-app/graphrag/config/models/snapshots_config.py
new file mode 100644
index 0000000000..08293fb7a7
--- /dev/null
+++ b/func-app/graphrag/config/models/snapshots_config.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+
+
+class SnapshotsConfig(BaseModel):
+    """Configuration section for snapshots."""
+
+    graphml: bool = Field(
+        description="A flag indicating whether to take snapshots of GraphML.",
+        default=defs.SNAPSHOTS_GRAPHML,
+    )
+    raw_entities: bool = Field(
+        description="A flag indicating whether to take snapshots of raw entities.",
+        default=defs.SNAPSHOTS_RAW_ENTITIES,
+    )
+    top_level_nodes: bool = Field(
+        description="A flag indicating whether to take snapshots of top-level nodes.",
+        default=defs.SNAPSHOTS_TOP_LEVEL_NODES,
+    )
diff --git a/func-app/graphrag/config/models/storage_config.py b/func-app/graphrag/config/models/storage_config.py
new file mode 100644
index 0000000000..b3b5c70fe0
--- /dev/null
+++ b/func-app/graphrag/config/models/storage_config.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+from graphrag.config.enums import StorageType
+
+
+class StorageConfig(BaseModel):
+    """The default configuration section for Storage."""
+
+    type: StorageType = Field(
+        description="The storage type to use.", default=defs.STORAGE_TYPE
+    )
+    base_dir: str = Field(
+        description="The base directory for the storage.",
+        default=defs.STORAGE_BASE_DIR,
+    )
+    connection_string: str | None = Field(
+        description="The storage connection string to use.", default=None
+    )
+    container_name: str | None = Field(
+        description="The storage container name to use.", default=None
+    )
+    storage_account_blob_url: str | None = Field(
+        description="The storage account blob url to use.", default=None
+    )
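+    # Controls whether indexing may reuse/overwrite an existing storage container.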
diff --git a/func-app/graphrag/config/models/storage_config.py b/func-app/graphrag/config/models/storage_config.py
new file mode 100644
index 0000000000..b3b5c70fe0
--- /dev/null
+++ b/func-app/graphrag/config/models/storage_config.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+import graphrag.config.defaults as defs
+from graphrag.config.enums import StorageType
+
+
+class StorageConfig(BaseModel):
+    """The default configuration section for Storage."""
+
+    type: StorageType = Field(
+        description="The storage type to use.", default=defs.STORAGE_TYPE
+    )
+    base_dir: str = Field(
+        description="The base directory for the storage.",
+        default=defs.STORAGE_BASE_DIR,
+    )
+    connection_string: str | None = Field(
+        description="The storage connection string to use.", default=None
+    )
+    container_name: str | None = Field(
+        description="The storage container name to use.", default=None
+    )
+    storage_account_blob_url: str | None = Field(
+        description="The storage account blob url to use.", default=None
+    )
+    overwrite: bool = Field(
+        description="If true, overwrite existing containers without raising an error; otherwise, raise an error.",
+        default=False,
+    )
diff --git a/func-app/graphrag/config/models/summarize_descriptions_config.py b/func-app/graphrag/config/models/summarize_descriptions_config.py
new file mode 100644
index 0000000000..9747d949c6
--- /dev/null
+++ b/func-app/graphrag/config/models/summarize_descriptions_config.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pathlib import Path
+
+from pydantic import Field
+
+import graphrag.config.defaults as defs
+
+from .llm_config import LLMConfig
+
+
+class SummarizeDescriptionsConfig(LLMConfig):
+    """Configuration section for description summarization."""
+
+    prompt: str | None = Field(
+        description="The description summarization prompt to use.", default=None
+    )
+    max_length: int = Field(
+        description="The description summarization maximum length.",
+        default=defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH,
+    )
+    strategy: dict | None = Field(
+        description="The override strategy to use.", default=None
+    )
+
+    def resolved_strategy(self, root_dir: str) -> dict:
+        """Get the resolved description summarization strategy."""
+        from graphrag.index.verbs.entities.summarize import SummarizeStrategyType
+
+        return self.strategy or {
+            "type": SummarizeStrategyType.graph_intelligence,
+            "llm": self.llm.model_dump(),
+            **self.parallelization.model_dump(),
+            "summarize_prompt": (Path(root_dir) / self.prompt)
+            .read_bytes()
+            .decode(encoding="utf-8")
+            if self.prompt
+            else None,
+            "max_summary_length": self.max_length,
+        }
diff --git a/func-app/graphrag/config/models/text_embedding_config.py b/func-app/graphrag/config/models/text_embedding_config.py
new file mode 100644
index 0000000000..5c2fcdb86e
--- /dev/null
+++ b/func-app/graphrag/config/models/text_embedding_config.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pydantic import Field + +import graphrag.config.defaults as defs +from graphrag.config.enums import TextEmbeddingTarget + +from .llm_config import LLMConfig + + +class TextEmbeddingConfig(LLMConfig): + """Configuration section for text embeddings.""" + + batch_size: int = Field( + description="The batch size to use.", default=defs.EMBEDDING_BATCH_SIZE + ) + batch_max_tokens: int = Field( + description="The batch max tokens to use.", + default=defs.EMBEDDING_BATCH_MAX_TOKENS, + ) + target: TextEmbeddingTarget = Field( + description="The target to use. 'all' or 'required'.", + default=defs.EMBEDDING_TARGET, + ) + skip: list[str] = Field(description="The specific embeddings to skip.", default=[]) + vector_store: dict | None = Field( + description="The vector storage configuration", default=None + ) + strategy: dict | None = Field( + description="The override strategy to use.", default=None + ) + + def resolved_strategy(self) -> dict: + """Get the resolved text embedding strategy.""" + from graphrag.index.verbs.text.embed import TextEmbedStrategyType + + return self.strategy or { + "type": TextEmbedStrategyType.openai, + "llm": self.llm.model_dump(), + **self.parallelization.model_dump(), + "batch_size": self.batch_size, + "batch_max_tokens": self.batch_max_tokens, + } diff --git a/func-app/graphrag/config/models/umap_config.py b/func-app/graphrag/config/models/umap_config.py new file mode 100644 index 0000000000..1d9bd93ead --- /dev/null +++ b/func-app/graphrag/config/models/umap_config.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pydantic import BaseModel, Field + +import graphrag.config.defaults as defs + + +class UmapConfig(BaseModel): + """Configuration section for UMAP.""" + + enabled: bool = Field( + description="A flag indicating whether to enable UMAP.", + default=defs.UMAP_ENABLED, + ) diff --git a/func-app/graphrag/config/read_dotenv.py b/func-app/graphrag/config/read_dotenv.py new file mode 100644 index 0000000000..7e041757b3 --- /dev/null +++ b/func-app/graphrag/config/read_dotenv.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing the read_dotenv utility.""" + +import logging +import os +from pathlib import Path + +from dotenv import dotenv_values + +log = logging.getLogger(__name__) + + +def read_dotenv(root: str) -> None: + """Read a .env file in the given root path.""" + env_path = Path(root) / ".env" + if env_path.exists(): + log.info("Loading pipeline .env file") + env_config = dotenv_values(f"{env_path}") + for key, value in env_config.items(): + if key not in os.environ: + os.environ[key] = value or "" + else: + log.info("No .env file found at %s", root) diff --git a/func-app/graphrag/index/__init__.py b/func-app/graphrag/index/__init__.py new file mode 100644 index 0000000000..38ab263620 --- /dev/null +++ b/func-app/graphrag/index/__init__.py @@ -0,0 +1,78 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""The Indexing Engine package root.""" + +from .cache import PipelineCache +from .config import ( + PipelineBlobCacheConfig, + PipelineBlobReportingConfig, + PipelineBlobStorageConfig, + PipelineCacheConfig, + PipelineCacheConfigTypes, + PipelineConfig, + PipelineConsoleReportingConfig, + PipelineCSVInputConfig, + PipelineFileCacheConfig, + PipelineFileReportingConfig, + PipelineFileStorageConfig, + PipelineInputConfig, + PipelineInputConfigTypes, + PipelineMemoryCacheConfig, + PipelineMemoryStorageConfig, + PipelineNoneCacheConfig, + PipelineReportingConfig, + PipelineReportingConfigTypes, + PipelineStorageConfig, + PipelineStorageConfigTypes, + PipelineTextInputConfig, + PipelineWorkflowConfig, + PipelineWorkflowReference, + PipelineWorkflowStep, +) +from .create_pipeline_config import create_pipeline_config +from .errors import ( + NoWorkflowsDefinedError, + UndefinedWorkflowError, + UnknownWorkflowError, +) +from .load_pipeline_config import load_pipeline_config +from .run import run_pipeline, run_pipeline_with_config +from graphrag.common.storage import PipelineStorage + +__all__ = [ + "NoWorkflowsDefinedError", + "PipelineBlobCacheConfig", + "PipelineBlobCacheConfig", + "PipelineBlobReportingConfig", + "PipelineBlobStorageConfig", + "PipelineCSVInputConfig", + "PipelineCache", + "PipelineCacheConfig", + "PipelineCacheConfigTypes", + "PipelineConfig", + "PipelineConsoleReportingConfig", + "PipelineFileCacheConfig", + "PipelineFileReportingConfig", + "PipelineFileStorageConfig", + "PipelineInputConfig", + "PipelineInputConfigTypes", + "PipelineMemoryCacheConfig", + "PipelineMemoryStorageConfig", + "PipelineNoneCacheConfig", + "PipelineReportingConfig", + "PipelineReportingConfigTypes", + "PipelineStorage", + "PipelineStorageConfig", + "PipelineStorageConfigTypes", + "PipelineTextInputConfig", + "PipelineWorkflowConfig", + "PipelineWorkflowReference", + "PipelineWorkflowStep", + "UndefinedWorkflowError", + "UnknownWorkflowError", + "create_pipeline_config", + "load_pipeline_config", + "run_pipeline", + "run_pipeline_with_config", +] diff --git a/func-app/graphrag/index/__main__.py b/func-app/graphrag/index/__main__.py new file mode 100644 index 0000000000..de2c156a69 --- /dev/null +++ b/func-app/graphrag/index/__main__.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine package root.""" + +import argparse + +from .cli import index_cli + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + help="The configuration yaml file to use when running the pipeline", + required=False, + type=str, + ) + parser.add_argument( + "-v", + "--verbose", + help="Runs the pipeline with verbose logging", + action="store_true", + ) + parser.add_argument( + "--memprofile", + help="Runs the pipeline with memory profiling", + action="store_true", + ) + parser.add_argument( + "--root", + help="If no configuration is defined, the root directory to use for input data and output data. Default value: the current directory", + # Only required if config is not defined + required=False, + default=".", + type=str, + ) + parser.add_argument( + "--resume", + help="Resume a given data run leveraging Parquet output files.", + # Only required if config is not defined + required=False, + default=None, + type=str, + ) + parser.add_argument( + "--reporter", + help="The progress reporter to use. 
Valid values are 'rich', 'print', or 'none'",
+        type=str,
+    )
+    parser.add_argument(
+        "--emit",
+        help="The data formats to emit, comma-separated. Valid values are 'parquet' and 'csv'. default='parquet,csv'",
+        type=str,
+    )
+    parser.add_argument(
+        "--context_id",
+        required=False,
+        help="Context id to activate or deactivate.",
+        type=str
+    )
+    parser.add_argument(
+        "--context_operation",
+        help="Context operation: activate or deactivate.",
+        required=False,
+        # Only required if context_id is provided
+        type=str
+    )
+    parser.add_argument(
+        "--dryrun",
+        help="Run the pipeline without actually executing any steps and inspect the configuration.",
+        action="store_true",
+    )
+    parser.add_argument("--nocache", help="Disable LLM cache.", action="store_true")
+    parser.add_argument(
+        "--init",
+        help="Create an initial configuration in the given path.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--overlay_defaults",
+        help="Overlay default configuration values on a provided configuration file (--config).",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--community_level",
+        help="Community level in the Leiden community hierarchy from which we will load the community reports. A higher value means we use reports on smaller communities.",
+        type=int,
+        default=2,
+    )
+    parser.add_argument(
+        "--use_kusto_community_reports",
+        help="If enabled, community reports are loaded into Kusto during activation.",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--optimized_search",
+        help="Runs optimized search and exports artifacts.",
+        action="store_true",
+    )
+
+    args = parser.parse_args()
+
+    if args.overlay_defaults and not args.config:
+        parser.error("--overlay_defaults requires --config")
+
+    index_cli(
+        root=args.root,
+        verbose=args.verbose or False,
+        resume=args.resume,
+        memprofile=args.memprofile or False,
+        nocache=args.nocache or False,
+        reporter=args.reporter,
+        config=args.config,
+        emit=args.emit,
+        dryrun=args.dryrun or False,
+        init=args.init or False,
+        overlay_defaults=args.overlay_defaults or False,
+        cli=True,
+        context_id=args.context_id,
+        context_operation=args.context_operation,
+        community_level=args.community_level,
+        use_kusto_community_reports=args.use_kusto_community_reports,
+        optimized_search=args.optimized_search
+    )
diff --git a/func-app/graphrag/index/bootstrap.py b/func-app/graphrag/index/bootstrap.py
new file mode 100644
index 0000000000..398ec88b20
--- /dev/null
+++ b/func-app/graphrag/index/bootstrap.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Bootstrap definition."""
+
+import warnings
+
+# Ignore warnings from numba
+warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
+warnings.filterwarnings("ignore", message=".*Use no seed for parallelism.*")
+
+initialized_nltk = False
+
+
+def bootstrap():
+    """Bootstrap definition."""
+    global initialized_nltk
+    if not initialized_nltk:
+        import nltk
+        from nltk.corpus import wordnet as wn
+
+        nltk.download("punkt")
+        nltk.download("averaged_perceptron_tagger")
+        nltk.download("maxent_ne_chunker")
+        nltk.download("words")
+        nltk.download("wordnet")
+        wn.ensure_loaded()
+        initialized_nltk = True
diff --git a/func-app/graphrag/index/cache/__init__.py b/func-app/graphrag/index/cache/__init__.py
new file mode 100644
index 0000000000..42ebb22994
--- /dev/null
+++ b/func-app/graphrag/index/cache/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""The Indexing Engine cache package root.""" + +from .json_pipeline_cache import JsonPipelineCache +from .load_cache import load_cache +from .memory_pipeline_cache import InMemoryCache +from .noop_pipeline_cache import NoopPipelineCache +from .pipeline_cache import PipelineCache + +__all__ = [ + "InMemoryCache", + "JsonPipelineCache", + "NoopPipelineCache", + "PipelineCache", + "load_cache", +] diff --git a/func-app/graphrag/index/cache/json_pipeline_cache.py b/func-app/graphrag/index/cache/json_pipeline_cache.py new file mode 100644 index 0000000000..30e73fedc6 --- /dev/null +++ b/func-app/graphrag/index/cache/json_pipeline_cache.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'FilePipelineCache' model.""" + +import json +from typing import Any + +from graphrag.common.storage.typing import PipelineStorage + +from .pipeline_cache import PipelineCache + + +class JsonPipelineCache(PipelineCache): + """File pipeline cache class definition.""" + + _storage: PipelineStorage + _encoding: str + + def __init__(self, storage: PipelineStorage, encoding="utf-8"): + """Init method definition.""" + self._storage = storage + self._encoding = encoding + + async def get(self, key: str) -> str | None: + """Get method definition.""" + if await self.has(key): + try: + data = await self._storage.get(key, encoding=self._encoding) + data = json.loads(data) + except UnicodeDecodeError: + await self._storage.delete(key) + return None + except json.decoder.JSONDecodeError: + await self._storage.delete(key) + return None + else: + return data.get("result") + + return None + + async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None: + """Set method definition.""" + if value is None: + return + data = {"result": value, **(debug_data or {})} + await self._storage.set(key, json.dumps(data), encoding=self._encoding) + + async def has(self, key: str) -> bool: + """Has method definition.""" + return await self._storage.has(key) + + async def delete(self, key: str) -> None: + """Delete method definition.""" + if await self.has(key): + await self._storage.delete(key) + + async def clear(self) -> None: + """Clear method definition.""" + await self._storage.clear() + + def child(self, name: str) -> "JsonPipelineCache": + """Child method definition.""" + return JsonPipelineCache(self._storage.child(name), encoding=self._encoding) diff --git a/func-app/graphrag/index/cache/load_cache.py b/func-app/graphrag/index/cache/load_cache.py new file mode 100644 index 0000000000..1a97b2e4de --- /dev/null +++ b/func-app/graphrag/index/cache/load_cache.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024 Microsoft Corporation. 
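Note: a minimal usage sketch for the loader defined below, assuming a file cache rooted under ./cache (types are from this patch; the paths are hypothetical):

    cache = load_cache(PipelineFileCacheConfig(base_dir="cache"), root_dir=".")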
+# Licensed under the MIT License + +"""A module containing load_cache method definition.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from graphrag.config.enums import CacheType +from graphrag.index.config.cache import ( + PipelineBlobCacheConfig, + PipelineFileCacheConfig, +) +from graphrag.common.storage import BlobPipelineStorage, FilePipelineStorage + +if TYPE_CHECKING: + from graphrag.index.config import ( + PipelineCacheConfig, + ) + +from .json_pipeline_cache import JsonPipelineCache +from .memory_pipeline_cache import create_memory_cache +from .noop_pipeline_cache import NoopPipelineCache + + +def load_cache(config: PipelineCacheConfig | None, root_dir: str | None): + """Load the cache from the given config.""" + if config is None: + return NoopPipelineCache() + + match config.type: + case CacheType.none: + return NoopPipelineCache() + case CacheType.memory: + return create_memory_cache() + case CacheType.file: + config = cast(PipelineFileCacheConfig, config) + storage = FilePipelineStorage(root_dir).child(config.base_dir) + return JsonPipelineCache(storage) + case CacheType.blob: + config = cast(PipelineBlobCacheConfig, config) + storage = BlobPipelineStorage( + config.connection_string, + config.container_name, + storage_account_blob_url=config.storage_account_blob_url, + ).child(config.base_dir) + return JsonPipelineCache(storage) + case _: + msg = f"Unknown cache type: {config.type}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/cache/memory_pipeline_cache.py b/func-app/graphrag/index/cache/memory_pipeline_cache.py new file mode 100644 index 0000000000..fa42f3f921 --- /dev/null +++ b/func-app/graphrag/index/cache/memory_pipeline_cache.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'InMemoryCache' model.""" + +from typing import Any + +from .pipeline_cache import PipelineCache + + +class InMemoryCache(PipelineCache): + """In memory cache class definition.""" + + _cache: dict[str, Any] + _name: str + + def __init__(self, name: str | None = None): + """Init method definition.""" + self._cache = {} + self._name = name or "" + + async def get(self, key: str) -> Any: + """Get the value for the given key. + + Args: + - key - The key to get the value for. + - as_bytes - Whether or not to return the value as bytes. + + Returns + ------- + - output - The value for the given key. + """ + key = self._create_cache_key(key) + return self._cache.get(key) + + async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None: + """Set the value for the given key. + + Args: + - key - The key to set the value for. + - value - The value to set. + """ + key = self._create_cache_key(key) + self._cache[key] = value + + async def has(self, key: str) -> bool: + """Return True if the given key exists in the storage. + + Args: + - key - The key to check for. + + Returns + ------- + - output - True if the key exists in the storage, False otherwise. + """ + key = self._create_cache_key(key) + return key in self._cache + + async def delete(self, key: str) -> None: + """Delete the given key from the storage. + + Args: + - key - The key to delete. 
+ """ + key = self._create_cache_key(key) + del self._cache[key] + + async def clear(self) -> None: + """Clear the storage.""" + self._cache.clear() + + def child(self, name: str) -> PipelineCache: + """Create a sub cache with the given name.""" + return InMemoryCache(name) + + def _create_cache_key(self, key: str) -> str: + """Create a cache key for the given key.""" + return f"{self._name}{key}" + + +def create_memory_cache() -> PipelineCache: + """Create a memory cache.""" + return InMemoryCache() diff --git a/func-app/graphrag/index/cache/noop_pipeline_cache.py b/func-app/graphrag/index/cache/noop_pipeline_cache.py new file mode 100644 index 0000000000..b7c3e60fdd --- /dev/null +++ b/func-app/graphrag/index/cache/noop_pipeline_cache.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Module containing the NoopPipelineCache implementation.""" + +from typing import Any + +from .pipeline_cache import PipelineCache + + +class NoopPipelineCache(PipelineCache): + """A no-op implementation of the pipeline cache, usually useful for testing.""" + + async def get(self, key: str) -> Any: + """Get the value for the given key. + + Args: + - key - The key to get the value for. + - as_bytes - Whether or not to return the value as bytes. + + Returns + ------- + - output - The value for the given key. + """ + return None + + async def set( + self, key: str, value: str | bytes | None, debug_data: dict | None = None + ) -> None: + """Set the value for the given key. + + Args: + - key - The key to set the value for. + - value - The value to set. + """ + + async def has(self, key: str) -> bool: + """Return True if the given key exists in the cache. + + Args: + - key - The key to check for. + + Returns + ------- + - output - True if the key exists in the cache, False otherwise. + """ + return False + + async def delete(self, key: str) -> None: + """Delete the given key from the cache. + + Args: + - key - The key to delete. + """ + + async def clear(self) -> None: + """Clear the cache.""" + + def child(self, name: str) -> PipelineCache: + """Create a child cache with the given name. + + Args: + - name - The name to create the sub cache with. + """ + return self diff --git a/func-app/graphrag/index/cache/pipeline_cache.py b/func-app/graphrag/index/cache/pipeline_cache.py new file mode 100644 index 0000000000..c68c5cfb4b --- /dev/null +++ b/func-app/graphrag/index/cache/pipeline_cache.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'PipelineCache' model.""" + +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import Any + + +class PipelineCache(metaclass=ABCMeta): + """Provide a cache interface for the pipeline.""" + + @abstractmethod + async def get(self, key: str) -> Any: + """Get the value for the given key. + + Args: + - key - The key to get the value for. + - as_bytes - Whether or not to return the value as bytes. + + Returns + ------- + - output - The value for the given key. + """ + + @abstractmethod + async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None: + """Set the value for the given key. + + Args: + - key - The key to set the value for. + - value - The value to set. + """ + + @abstractmethod + async def has(self, key: str) -> bool: + """Return True if the given key exists in the cache. + + Args: + - key - The key to check for. 
+ + Returns + ------- + - output - True if the key exists in the cache, False otherwise. + """ + + @abstractmethod + async def delete(self, key: str) -> None: + """Delete the given key from the cache. + + Args: + - key - The key to delete. + """ + + @abstractmethod + async def clear(self) -> None: + """Clear the cache.""" + + @abstractmethod + def child(self, name: str) -> PipelineCache: + """Create a child cache with the given name. + + Args: + - name - The name to create the sub cache with. + """ diff --git a/func-app/graphrag/index/cli.py b/func-app/graphrag/index/cli.py new file mode 100644 index 0000000000..2695f46af9 --- /dev/null +++ b/func-app/graphrag/index/cli.py @@ -0,0 +1,356 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Main definition.""" + +import asyncio +import json +import logging +import platform +import sys +import time +import warnings +from pathlib import Path + +from graphrag.config import ( + GraphRagConfig, + create_graphrag_config, +) +from graphrag.config.enums import ContextSwitchType +from graphrag.common.utils.common_utils import is_valid_guid +from graphrag.index import PipelineConfig, create_pipeline_config +from graphrag.index.cache import NoopPipelineCache +from graphrag.common.progress import ( + NullProgressReporter, + PrintProgressReporter, + ProgressReporter, +) +from graphrag.common.progress.rich import RichProgressReporter +from graphrag.index.run import run_pipeline_with_config + +from .emit import TableEmitterType +from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT +from .graph.extractors.community_reports.prompts import COMMUNITY_REPORT_PROMPT +from .graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT +from .graph.extractors.summarize.prompts import SUMMARIZE_PROMPT +from .init_content import INIT_DOTENV, INIT_YAML + +# Ignore warnings from numba +warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*") + +log = logging.getLogger(__name__) + +def redact(input: dict) -> str: + """Sanitize the config json.""" + + # Redact any sensitive configuration + def redact_dict(input: dict) -> dict: + if not isinstance(input, dict): + return input + + result = {} + for key, value in input.items(): + if key in { + "api_key", + "connection_string", + "container_name", + "organization", + }: + if value is not None: + result[key] = f"REDACTED, length {len(value)}" + elif isinstance(value, dict): + result[key] = redact_dict(value) + elif isinstance(value, list): + result[key] = [redact_dict(i) for i in value] + else: + result[key] = value + return result + + redacted_dict = redact_dict(input) + return json.dumps(redacted_dict, indent=4) + + +def index_cli( + root: str, + init: bool, + community_level: int, + context_operation: str | None, + context_id: str | None, + verbose: bool, + resume: str | None, + memprofile: bool, + nocache: bool, + config: str | None, + emit: str | None, + dryrun: bool, + overlay_defaults: bool, + cli: bool = False, + use_kusto_community_reports: bool = False, + optimized_search: bool = False, +): + """Run the pipeline with the given config.""" + root = Path(__file__).parent.parent.parent.__str__() + run_id = resume or time.strftime("%Y%m%d-%H%M%S") + _enable_logging(root, run_id, verbose) + progress_reporter = _get_progress_reporter("none") + _initialize_project_at(root, progress_reporter) + if overlay_defaults: + pipeline_config: str | PipelineConfig = _create_default_config( + root, config, verbose, dryrun or False, progress_reporter + ) + else: + 
pipeline_config: str | PipelineConfig = config or _create_default_config(
+            root, None, verbose, dryrun or False, progress_reporter
+        )
+
+    cache = NoopPipelineCache() if nocache else None
+    pipeline_emit = emit.split(",") if emit else None
+    encountered_errors = False
+    logging.info("Loaded the pipeline configuration successfully")
+
+    def _run_workflow_async() -> None:
+        import signal
+
+        def handle_signal(signum, _):
+            # Handle the signal here
+            progress_reporter.info(f"Received signal {signum}, exiting...")
+            progress_reporter.dispose()
+            for task in asyncio.all_tasks():
+                task.cancel()
+            progress_reporter.info("All tasks cancelled. Exiting...")
+
+        # Register signal handlers for SIGINT and SIGHUP
+        logging.info("Registering signal handlers")
+        #signal.signal(signal.SIGINT, handle_signal)
+
+        if sys.platform != "win32":
+            signal.signal(signal.SIGHUP, handle_signal)
+
+        async def execute():
+            nonlocal encountered_errors
+            async for output in run_pipeline_with_config(
+                pipeline_config,
+                run_id=run_id,
+                memory_profile=memprofile,
+                cache=cache,
+                progress_reporter=progress_reporter,
+                emit=(
+                    [TableEmitterType(e) for e in pipeline_emit]
+                    if pipeline_emit
+                    else None
+                ),
+                is_resume_run=bool(resume),
+                context_id=context_id,
+            ):
+                if output.errors and len(output.errors) > 0:
+                    encountered_errors = True
+                    progress_reporter.error(output.workflow)
+                else:
+                    progress_reporter.success(output.workflow)
+
+                progress_reporter.info(str(output.result))
+
+        if platform.system() == "Windows":
+            logging.info("All set to execute the workflows on Windows")
+            import nest_asyncio  # type: ignore Ignoring because outside Windows this will cause an error
+
+            nest_asyncio.apply()
+            loop = asyncio.get_event_loop()
+            loop.run_until_complete(execute())
+        elif sys.version_info >= (3, 11):
+            logging.info("Executing the pipeline with uvloop (Python >= 3.11)")
+            import uvloop  # type: ignore Ignoring because on Windows this will cause an error
+
+            with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:  # type: ignore Ignoring because on older minor versions this will throw an error
+                runner.run(execute())
+        else:
+            logging.info("Executing the pipeline with uvloop")
+            import uvloop  # type: ignore Ignoring because on Windows this will cause an error
+
+            uvloop.install()
+            asyncio.run(execute())
+
+    _run_workflow_async()
+    progress_reporter.stop()
+    if encountered_errors:
+        progress_reporter.error(
+            "Errors occurred during the pipeline run, see logs for more details."
+ ) + else: + progress_reporter.success("All workflows completed successfully.") + + if cli: + sys.exit(1 if encountered_errors else 0) + +def _switch_context(root: str, config: str, + reporter: ProgressReporter, context_operation: str | None, + context_id: str, community_level: int, optimized_search: bool, + use_kusto_community_reports: bool) -> None: + """Switch the context to the given context.""" + reporter.info(f"Switching context to {context_id} using operation {context_operation}") + logging.info("Switching context to {context_id}") + from graphrag.index.context_switch.contextSwitcher import ContextSwitcher + context_switcher = ContextSwitcher( + root_dir=root, + config_dir=config, + reporter=reporter, + context_id=context_id, + community_level=community_level, + data_dir=None, + optimized_search=optimized_search, + use_kusto_community_reports=use_kusto_community_reports) + if context_operation == ContextSwitchType.Activate: + context_switcher.activate() + elif context_operation == ContextSwitchType.Deactivate: + context_switcher.deactivate() + else: + msg = f"Invalid context operation {context_operation}" + raise ValueError(msg) + +def _initialize_project_at(path: str, reporter: ProgressReporter) -> None: + """Initialize the project at the given path.""" + reporter.info(f"Initializing project at {path}") + root = Path(path) + if not root.exists(): + root.mkdir(parents=True, exist_ok=True) + + settings_yaml = root / "settings/settings.yaml" + + dotenv = root / ".env" + if not dotenv.exists(): + with settings_yaml.open("wb") as file: + file.write(INIT_YAML.encode(encoding="utf-8", errors="strict")) + + with dotenv.open("wb") as file: + file.write(INIT_DOTENV.encode(encoding="utf-8", errors="strict")) + + prompts_dir = root / "prompts" + if not prompts_dir.exists(): + prompts_dir.mkdir(parents=True, exist_ok=True) + + entity_extraction = prompts_dir / "entity_extraction.txt" + if not entity_extraction.exists(): + with entity_extraction.open("wb") as file: + file.write( + GRAPH_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict") + ) + + summarize_descriptions = prompts_dir / "summarize_descriptions.txt" + if not summarize_descriptions.exists(): + with summarize_descriptions.open("wb") as file: + file.write(SUMMARIZE_PROMPT.encode(encoding="utf-8", errors="strict")) + + claim_extraction = prompts_dir / "claim_extraction.txt" + if not claim_extraction.exists(): + with claim_extraction.open("wb") as file: + file.write( + CLAIM_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict") + ) + + community_report = prompts_dir / "community_report.txt" + if not community_report.exists(): + with community_report.open("wb") as file: + file.write( + COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict") + ) + + +def _create_default_config( + root: str, + config: str | None, + verbose: bool, + dryrun: bool, + reporter: ProgressReporter, +) -> PipelineConfig: + """Overlay default values on an existing config or create a default config if none is provided.""" + if config and not Path(config).exists(): + msg = f"Configuration file {config} does not exist" + raise ValueError + + if not Path(root).exists(): + msg = f"Root directory {root} does not exist" + raise ValueError(msg) + + parameters = _read_config_parameters(root, config, reporter) + log.info( + "using default configuration: %s", + redact(parameters.model_dump()), + ) + + if verbose or dryrun: + reporter.info(f"Using default configuration: {redact(parameters.model_dump())}") + result = create_pipeline_config(parameters, 
verbose) + if verbose or dryrun: + reporter.info(f"Final Config: {redact(result.model_dump())}") + + if dryrun: + reporter.info("dry run complete, exiting...") + sys.exit(0) + return result + + +def _read_config_parameters(root: str, config: str | None, reporter: ProgressReporter): + _root = Path(root) + settings_yaml = ( + Path(config) + if config and Path(config).suffix in [".yaml", ".yml"] + else _root / "settings.yaml" + ) + if not settings_yaml.exists(): + settings_yaml = _root / "settings.yml" + settings_json = ( + Path(config) + if config and Path(config).suffix == ".json" + else _root / "settings.json" + ) + + if settings_yaml.exists(): + reporter.success(f"Reading settings from {settings_yaml}") + with settings_yaml.open("rb") as file: + import yaml + + data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + if settings_json.exists(): + reporter.success(f"Reading settings from {settings_json}") + with settings_json.open("rb") as file: + import json + + data = json.loads(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + reporter.success("Reading settings from environment variables") + return create_graphrag_config(root_dir=root) + + +def _get_progress_reporter(reporter_type: str | None) -> ProgressReporter: + if reporter_type is None or reporter_type == "rich": + return RichProgressReporter("GraphRAG Indexer ") + if reporter_type == "print": + return PrintProgressReporter("GraphRAG Indexer ") + if reporter_type == "none": + return NullProgressReporter() + + msg = f"Invalid progress reporter type: {reporter_type}" + raise ValueError(msg) + + +def _enable_logging(root_dir: str, run_id: str, verbose: bool) -> None: + logging_file = ( + Path(root_dir) / "output" / run_id / "reports" / "indexing-engine.log" + ) + logging_file.parent.mkdir(parents=True, exist_ok=True) + + logging_file.touch(exist_ok=True) + handler = logging.StreamHandler(stream=sys.stdout) + fileHandler = logging.FileHandler(logging_file, mode="a") + logging.basicConfig( + #filename=str(logging_file), + #filemode="a", + format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", + datefmt="%H:%M:%S", + level=logging.DEBUG if verbose else logging.INFO, + handlers=[handler, fileHandler] + ) diff --git a/func-app/graphrag/index/config/__init__.py b/func-app/graphrag/index/config/__init__.py new file mode 100644 index 0000000000..ad30859b81 --- /dev/null +++ b/func-app/graphrag/index/config/__init__.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""The Indexing Engine config typing package root.""" + +from .cache import ( + PipelineBlobCacheConfig, + PipelineCacheConfig, + PipelineCacheConfigTypes, + PipelineFileCacheConfig, + PipelineMemoryCacheConfig, + PipelineNoneCacheConfig, +) +from .input import ( + PipelineCSVInputConfig, + PipelineInputConfig, + PipelineInputConfigTypes, + PipelineTextInputConfig, +) +from .pipeline import PipelineConfig +from .reporting import ( + PipelineBlobReportingConfig, + PipelineConsoleReportingConfig, + PipelineFileReportingConfig, + PipelineReportingConfig, + PipelineReportingConfigTypes, +) +from ...common.config.storage import ( + PipelineBlobStorageConfig, + PipelineFileStorageConfig, + PipelineMemoryStorageConfig, + PipelineStorageConfig, + PipelineStorageConfigTypes, +) +from .workflow import ( + PipelineWorkflowConfig, + PipelineWorkflowReference, + PipelineWorkflowStep, +) + +__all__ = [ + "PipelineBlobCacheConfig", + "PipelineBlobReportingConfig", + "PipelineBlobStorageConfig", + "PipelineCSVInputConfig", + "PipelineCacheConfig", + "PipelineCacheConfigTypes", + "PipelineCacheConfigTypes", + "PipelineCacheConfigTypes", + "PipelineConfig", + "PipelineConsoleReportingConfig", + "PipelineFileCacheConfig", + "PipelineFileReportingConfig", + "PipelineFileStorageConfig", + "PipelineInputConfig", + "PipelineInputConfigTypes", + "PipelineMemoryCacheConfig", + "PipelineMemoryCacheConfig", + "PipelineMemoryStorageConfig", + "PipelineNoneCacheConfig", + "PipelineReportingConfig", + "PipelineReportingConfigTypes", + "PipelineStorageConfig", + "PipelineStorageConfigTypes", + "PipelineTextInputConfig", + "PipelineWorkflowConfig", + "PipelineWorkflowReference", + "PipelineWorkflowStep", +] diff --git a/func-app/graphrag/index/config/cache.py b/func-app/graphrag/index/config/cache.py new file mode 100644 index 0000000000..be1053de2e --- /dev/null +++ b/func-app/graphrag/index/config/cache.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing 'PipelineCacheConfig', 'PipelineFileCacheConfig' and 'PipelineMemoryCacheConfig' models.""" + +from __future__ import annotations + +from typing import Generic, Literal, TypeVar + +from pydantic import BaseModel +from pydantic import Field as pydantic_Field + +from graphrag.config.enums import CacheType + +T = TypeVar("T") + + +class PipelineCacheConfig(BaseModel, Generic[T]): + """Represent the cache configuration for the pipeline.""" + + type: T + + +class PipelineFileCacheConfig(PipelineCacheConfig[Literal[CacheType.file]]): + """Represent the file cache configuration for the pipeline.""" + + type: Literal[CacheType.file] = CacheType.file + """The type of cache.""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the cache.", default=None + ) + """The base directory for the cache.""" + + +class PipelineMemoryCacheConfig(PipelineCacheConfig[Literal[CacheType.memory]]): + """Represent the memory cache configuration for the pipeline.""" + + type: Literal[CacheType.memory] = CacheType.memory + """The type of cache.""" + + +class PipelineNoneCacheConfig(PipelineCacheConfig[Literal[CacheType.none]]): + """Represent the none cache configuration for the pipeline.""" + + type: Literal[CacheType.none] = CacheType.none + """The type of cache.""" + + +class PipelineBlobCacheConfig(PipelineCacheConfig[Literal[CacheType.blob]]): + """Represents the blob cache configuration for the pipeline.""" + + type: Literal[CacheType.blob] = CacheType.blob + """The type of cache.""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the cache.", default=None + ) + """The base directory for the cache.""" + + connection_string: str | None = pydantic_Field( + description="The blob cache connection string for the cache.", default=None + ) + """The blob cache connection string for the cache.""" + + container_name: str = pydantic_Field( + description="The container name for cache", default=None + ) + """The container name for cache""" + + storage_account_blob_url: str | None = pydantic_Field( + description="The storage account blob url for cache", default=None + ) + """The storage account blob url for cache""" + + +PipelineCacheConfigTypes = ( + PipelineFileCacheConfig + | PipelineMemoryCacheConfig + | PipelineBlobCacheConfig + | PipelineNoneCacheConfig +) diff --git a/func-app/graphrag/index/config/input.py b/func-app/graphrag/index/config/input.py new file mode 100644 index 0000000000..35db357599 --- /dev/null +++ b/func-app/graphrag/index/config/input.py @@ -0,0 +1,120 @@ +# Copyright (c) 2024 Microsoft Corporation. 
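Note: an illustrative `input` section of settings.yaml for the models below (directory and pattern are hypothetical):

    input:
      file_type: text
      base_dir: input
      file_pattern: ".*\\.txt$"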
+# Licensed under the MIT License + +"""A module containing 'PipelineInputConfig', 'PipelineCSVInputConfig' and 'PipelineTextInputConfig' models.""" + +from __future__ import annotations + +from typing import Generic, Literal, TypeVar + +from pydantic import BaseModel +from pydantic import Field as pydantic_Field + +from graphrag.config.enums import InputFileType, InputType + +from .workflow import PipelineWorkflowStep + +T = TypeVar("T") + + +class PipelineInputConfig(BaseModel, Generic[T]): + """Represent the configuration for an input.""" + + file_type: T + """The file type of input.""" + + type: InputType | None = pydantic_Field( + description="The input type to use.", + default=None, + ) + """The input type to use.""" + + connection_string: str | None = pydantic_Field( + description="The blob cache connection string for the input files.", + default=None, + ) + """The blob cache connection string for the input files.""" + + storage_account_blob_url: str | None = pydantic_Field( + description="The storage account blob url for the input files.", default=None + ) + """The storage account blob url for the input files.""" + + container_name: str | None = pydantic_Field( + description="The container name for input files.", default=None + ) + """The container name for the input files.""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the input files.", default=None + ) + """The base directory for the input files.""" + + file_pattern: str = pydantic_Field( + description="The regex file pattern for the input files." + ) + """The regex file pattern for the input files.""" + + file_filter: dict[str, str] | None = pydantic_Field( + description="The optional file filter for the input files.", default=None + ) + """The optional file filter for the input files.""" + + post_process: list[PipelineWorkflowStep] | None = pydantic_Field( + description="The post processing steps for the input.", default=None + ) + """The post processing steps for the input.""" + + encoding: str | None = pydantic_Field( + description="The encoding for the input files.", default=None + ) + """The encoding for the input files.""" + + +class PipelineCSVInputConfig(PipelineInputConfig[Literal[InputFileType.csv]]): + """Represent the configuration for a CSV input.""" + + file_type: Literal[InputFileType.csv] = InputFileType.csv + + source_column: str | None = pydantic_Field( + description="The column to use as the source of the document.", default=None + ) + """The column to use as the source of the document.""" + + timestamp_column: str | None = pydantic_Field( + description="The column to use as the timestamp of the document.", default=None + ) + """The column to use as the timestamp of the document.""" + + timestamp_format: str | None = pydantic_Field( + description="The format of the timestamp column, so it can be parsed correctly.", + default=None, + ) + """The format of the timestamp column, so it can be parsed correctly.""" + + text_column: str | None = pydantic_Field( + description="The column to use as the text of the document.", default=None + ) + """The column to use as the text of the document.""" + + title_column: str | None = pydantic_Field( + description="The column to use as the title of the document.", default=None + ) + """The column to use as the title of the document.""" + + +class PipelineTextInputConfig(PipelineInputConfig[Literal[InputFileType.text]]): + """Represent the configuration for a text input.""" + + file_type: Literal[InputFileType.text] = InputFileType.text + + 
# Text Specific + title_text_length: int | None = pydantic_Field( + description="Number of characters to use from the text as the title.", + default=None, + ) + """Number of characters to use from the text as the title.""" + + +PipelineInputConfigTypes = PipelineCSVInputConfig | PipelineTextInputConfig +"""Represent the types of inputs that can be used in a pipeline.""" diff --git a/func-app/graphrag/index/config/pipeline.py b/func-app/graphrag/index/config/pipeline.py new file mode 100644 index 0000000000..e8bbbdbf4c --- /dev/null +++ b/func-app/graphrag/index/config/pipeline.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'PipelineConfig' model.""" + +from __future__ import annotations + +from devtools import pformat +from graphrag.config.models.graphdb_config import GraphDBConfig +from pydantic import BaseModel +from pydantic import Field as pydantic_Field + +from .cache import PipelineCacheConfigTypes +from .input import PipelineInputConfigTypes +from .reporting import PipelineReportingConfigTypes +from ...common.config.storage import PipelineStorageConfigTypes +from .workflow import PipelineWorkflowReference + + +class PipelineConfig(BaseModel): + """Represent the configuration for a pipeline.""" + + def __repr__(self) -> str: + """Get a string representation.""" + return pformat(self, highlight=False) + + def __str__(self): + """Get a string representation.""" + return str(self.model_dump_json(indent=4)) + + extends: list[str] | str | None = pydantic_Field( + description="Extends another pipeline configuration", default=None + ) + """Extends another pipeline configuration""" + + input: PipelineInputConfigTypes | None = pydantic_Field( + default=None, discriminator="file_type" + ) + """The input configuration for the pipeline.""" + + reporting: PipelineReportingConfigTypes | None = pydantic_Field( + default=None, discriminator="type" + ) + """The reporting configuration for the pipeline.""" + + storage: PipelineStorageConfigTypes | None = pydantic_Field( + default=None, discriminator="type" + ) + """The storage configuration for the pipeline.""" + + cache: PipelineCacheConfigTypes | None = pydantic_Field( + default=None, discriminator="type" + ) + """The cache configuration for the pipeline.""" + + root_dir: str | None = pydantic_Field( + description="The root directory for the pipeline. All other paths will be based on this root_dir.", + default=None, + ) + """The root directory for the pipeline.""" + + workflows: list[PipelineWorkflowReference] = pydantic_Field( + description="The workflows for the pipeline.", default_factory=list + ) + """The workflows for the pipeline.""" + + graphdb_params: GraphDBConfig|None = pydantic_Field( + description="Parameters for Graphdb collection", default=None + ) + """Parameters for Graphdb collection""" diff --git a/func-app/graphrag/index/config/reporting.py b/func-app/graphrag/index/config/reporting.py new file mode 100644 index 0000000000..921e24ae4e --- /dev/null +++ b/func-app/graphrag/index/config/reporting.py @@ -0,0 +1,77 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing 'PipelineReportingConfig', 'PipelineFileReportingConfig' and 'PipelineConsoleReportingConfig' models.""" + +from __future__ import annotations + +from typing import Generic, Literal, TypeVar + +from pydantic import BaseModel +from pydantic import Field as pydantic_Field + +from graphrag.config.enums import ReportingType + +T = TypeVar("T") + + +class PipelineReportingConfig(BaseModel, Generic[T]): + """Represent the reporting configuration for the pipeline.""" + + type: T + + +class PipelineFileReportingConfig(PipelineReportingConfig[Literal[ReportingType.file]]): + """Represent the file reporting configuration for the pipeline.""" + + type: Literal[ReportingType.file] = ReportingType.file + """The type of reporting.""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the reporting.", default=None + ) + """The base directory for the reporting.""" + + +class PipelineConsoleReportingConfig( + PipelineReportingConfig[Literal[ReportingType.console]] +): + """Represent the console reporting configuration for the pipeline.""" + + type: Literal[ReportingType.console] = ReportingType.console + """The type of reporting.""" + + +class PipelineBlobReportingConfig(PipelineReportingConfig[Literal[ReportingType.blob]]): + """Represents the blob reporting configuration for the pipeline.""" + + type: Literal[ReportingType.blob] = ReportingType.blob + """The type of reporting.""" + + connection_string: str | None = pydantic_Field( + description="The blob reporting connection string for the reporting.", + default=None, + ) + """The blob reporting connection string for the reporting.""" + + container_name: str = pydantic_Field( + description="The container name for reporting", default=None + ) + """The container name for reporting""" + + storage_account_blob_url: str | None = pydantic_Field( + description="The storage account blob url for reporting", default=None + ) + """The storage account blob url for reporting""" + + base_dir: str | None = pydantic_Field( + description="The base directory for the reporting.", default=None + ) + """The base directory for the reporting.""" + + +PipelineReportingConfigTypes = ( + PipelineFileReportingConfig + | PipelineConsoleReportingConfig + | PipelineBlobReportingConfig +) diff --git a/func-app/graphrag/index/config/workflow.py b/func-app/graphrag/index/config/workflow.py new file mode 100644 index 0000000000..c26fef6ca0 --- /dev/null +++ b/func-app/graphrag/index/config/workflow.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Microsoft Corporation. 
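Note: a workflow reference can carry its steps inline; a minimal sketch (the verb name and args here are hypothetical, not from this patch):

    PipelineWorkflowReference(
        name="my_workflow",
        steps=[{"verb": "derive", "args": {"column": "output"}}],
    )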
+# Licensed under the MIT License
+
+"""A module containing 'PipelineWorkflowReference' model."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from pydantic import BaseModel
+from pydantic import Field as pydantic_Field
+
+PipelineWorkflowStep = dict[str, Any]
+"""Represent a step in a workflow."""
+
+PipelineWorkflowConfig = dict[str, Any]
+"""Represent a configuration for a workflow."""
+
+
+class PipelineWorkflowReference(BaseModel):
+    """Represent a reference to a workflow, and can optionally be the workflow itself."""
+
+    name: str | None = pydantic_Field(description="Name of the workflow.", default=None)
+    """Name of the workflow."""
+
+    steps: list[PipelineWorkflowStep] | None = pydantic_Field(
+        description="The optional steps for the workflow.", default=None
+    )
+    """The optional steps for the workflow."""
+
+    config: PipelineWorkflowConfig | None = pydantic_Field(
+        description="The optional configuration for the workflow.", default=None
+    )
+    """The optional configuration for the workflow."""
diff --git a/func-app/graphrag/index/context.py b/func-app/graphrag/index/context.py
new file mode 100644
index 0000000000..cdec0f6292
--- /dev/null
+++ b/func-app/graphrag/index/context.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+# isort: skip_file
+"""A module containing the 'PipelineRunStats' and 'PipelineRunContext' models."""
+
+from dataclasses import dataclass as dc_dataclass
+from dataclasses import field
+
+from .cache import PipelineCache
+from graphrag.common.storage.typing import PipelineStorage
+
+
+@dc_dataclass
+class PipelineRunStats:
+    """Pipeline running stats."""
+
+    total_runtime: float = field(default=0)
+    """Float representing the total runtime."""
+
+    num_documents: int = field(default=0)
+    """Number of documents."""
+
+    input_load_time: float = field(default=0)
+    """Float representing the input load time."""
+
+    workflows: dict[str, dict[str, float]] = field(default_factory=dict)
+    """A dictionary of workflows."""
+
+
+@dc_dataclass
+class PipelineRunContext:
+    """Provides the context for the current pipeline run."""
+
+    stats: PipelineRunStats
+    storage: PipelineStorage
+    cache: PipelineCache
+
+
+# TODO: For now, just has the same props available to it
+VerbRunContext = PipelineRunContext
+"""Provides the context for the current verb run."""
diff --git a/func-app/graphrag/index/context_switch/contextSwitcher.py b/func-app/graphrag/index/context_switch/contextSwitcher.py
new file mode 100644
index 0000000000..26c007e428
--- /dev/null
+++ b/func-app/graphrag/index/context_switch/contextSwitcher.py
@@ -0,0 +1,288 @@
+import asyncio
+import os
+from io import BytesIO
+from pathlib import Path
+from typing import cast
+
+import pandas as pd
+
+from graphrag.common.graph_db_client import GraphDBClient
+from graphrag.common.progress import ProgressReporter
+from graphrag.common.storage import (
+    BlobPipelineStorage,
+    FilePipelineStorage,
+    PipelineStorage,
+)
+from graphrag.common.utils.context_utils import get_files_by_contextid
+from graphrag.config import (
+    GraphRagConfig,
+    create_graphrag_config,
+)
+from graphrag.config.enums import StorageType
+from graphrag.model.community_report import CommunityReport
+from graphrag.model import TextUnit
+from graphrag.model.entity import Entity
+from graphrag.query.indexer_adapters import (
+    read_indexer_entities,
+    read_indexer_reports,
+    read_indexer_text_units,
+)
+from azure.cosmos import CosmosClient,
PartitionKey +from graphrag.vector_stores.base import BaseVectorStore +from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType +import logging + +class ContextSwitcher: + """ContextSwitcher class definition.""" + + def __init__(self, root_dir:str , config_dir:str,reporter: ProgressReporter, + context_id:str, community_level:int , + data_dir: str = None, + optimized_search: bool= False, + use_kusto_community_reports: bool = False,): + + self.root_dir=root_dir + self.config_dir=config_dir + self.data_dir=data_dir + self.reporter=reporter + self.context_id=context_id + self.optimized_search=optimized_search + self.community_level = community_level + self.use_kusto_community_reports = use_kusto_community_reports + logging.info("ContextSwitcher initialized") + + def get_embedding_store(self,config_args): + """Set up the vector store and return it.""" + if not config_args: + config_args = {} + + collection_name = config_args.get( + "query_collection_name", "entity_description_embeddings" + ) + + collection_name += "_" + self.context_id + config_args.update({"collection_name": collection_name}) + + vector_name = config_args.get( + "vector_search_column", "description_embedding" + ) + config_args.update({"vector_name": vector_name}) + config_args.update({"reports_name": f"reports_{self.context_id}"}) + + + config_args.update({"text_units_name": f"text_units_{self.context_id}"}) + + return VectorStoreFactory.get_vector_store( + vector_store_type=VectorStoreType.Kusto, kwargs=config_args + ) + + + + def setup_vector_store(self, + config_args: dict | None = None,) -> BaseVectorStore: + + description_embedding_store = self.get_embedding_store(config_args) + description_embedding_store.connect(**config_args) + + description_embedding_store.setup_entities() + if self.use_kusto_community_reports: + description_embedding_store.setup_reports() + + description_embedding_store.setup_text_units() + + return description_embedding_store + + def _read_config_parameters(self,root: str, config: str | None): + reporter=self.reporter + _root = Path(root) + settings_yaml = ( + Path(config) + if config and Path(config).suffix in [".yaml", ".yml"] + else _root / "settings.yaml" + ) + if not settings_yaml.exists(): + settings_yaml = _root / "settings.yml" + + if settings_yaml.exists(): + reporter.info(f"Reading settings from {settings_yaml}") + with settings_yaml.open( + "rb", + ) as file: + import yaml + + data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + settings_json = ( + Path(config) + if config and Path(config).suffix == ".json" + else _root / "settings.json" + ) + if settings_json.exists(): + reporter.info(f"Reading settings from {settings_json}") + with settings_json.open("rb") as file: + import json + + data = json.loads(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + reporter.info("Reading settings from environment variables") + return create_graphrag_config(root_dir=root) + + def activate(self): + """Activate the context.""" + #1. read the context id to fileId mapping. + #2. read the file from storage using common/blob_storage_client.py + #3. GraphDB: use cosmos db client to load data into Cosmos DB. + #4. KustoDB: use Kusto client to load embedding data into Kusto. 
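+        # The context id -> file mapping is resolved below via
+        # get_files_by_contextid(config, context_id); each mapped path is then
+        # read as parquet and loaded into the configured stores.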
data_dir = self.data_dir
+        root_dir = self.root_dir
+        config_dir = self.config_dir
+        reporter = self.reporter
+        context_id = self.context_id
+        optimized_search = self.optimized_search
+        community_level = self.community_level
+
+        def read_parquet_file(storage: PipelineStorage, path: str):
+            # TODO: create a different enum for parquet storage types
+            file_data = asyncio.run(storage.get(path, True))
+            if file_data is None:
+                return pd.DataFrame()
+            return pd.read_parquet(BytesIO(file_data), engine="pyarrow")
+
+        def _configure_paths_and_settings(
+            data_dir: str | None,
+            root_dir: str | None,
+            config_dir: str | None,
+        ) -> tuple[str, str | None, GraphRagConfig]:
+            if data_dir is None and root_dir is None:
+                msg = "Either data_dir or root_dir must be provided."
+                raise ValueError(msg)
+            if data_dir is None:
+                data_dir = _infer_data_dir(cast(str, root_dir))
+            config = _create_graphrag_config(root_dir, config_dir)
+            return data_dir, root_dir, config
+
+        def _infer_data_dir(root: str) -> str:
+            output = Path(root) / "output"
+            # use the latest data-run folder
+            if output.exists():
+                folders = sorted(output.iterdir(), key=os.path.getmtime, reverse=True)
+                if len(folders) > 0:
+                    folder = folders[0]
+                    return str((folder / "artifacts").absolute())
+            msg = f"Could not infer data directory from root={root}"
+            raise ValueError(msg)
+
+        def _create_graphrag_config(
+            root: str | None,
+            config_dir: str | None,
+        ) -> GraphRagConfig:
+            """Create a GraphRag configuration."""
+            return self._read_config_parameters(root or "./", config_dir)
+
+        ################################################################################
+
+        _, _, config = _configure_paths_and_settings(
+            data_dir, root_dir, config_dir
+        )
+
+        if config.storage.type == StorageType.memory:
+            raise ValueError("Memory storage is not supported")
+        if config.storage.type == StorageType.blob:
+            if config.storage.container_name is not None:
+                input_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string, container_name=config.storage.container_name, storage_account_blob_url=config.storage.storage_account_blob_url)
+            else:
+                raise ValueError("Storage type is Blob but the container name is invalid")
+        if config.storage.type == StorageType.file:
+            input_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir)
+
+        data_paths = get_files_by_contextid(config, context_id)
+        final_nodes = pd.DataFrame()
+        final_community_reports = pd.DataFrame()
+        final_text_units = pd.DataFrame()
+        final_relationships = pd.DataFrame()
+        final_entities = pd.DataFrame()
+        final_covariates = pd.DataFrame()
+        graph_db_client = None
+
+        if config.graphdb.enabled:
+            cosmos_client = CosmosClient(
+                f"{config.graphdb.cosmos_url}",
+                f"{config.graphdb.account_key}",
+            )
+            database_name = config.graphdb.username.split("/")[2]
+            database = cosmos_client.get_database_client(database_name)
+            graph_name = config.graphdb.username.split("/")[-1] + "-contextid-" + context_id
+            graph = database.create_container_if_not_exists(
+                id=graph_name,
+                partition_key=PartitionKey(path='/category'),
+                offer_throughput=400
+            )
+            graph_db_client = GraphDBClient(config.graphdb, context_id)
+
+        description_embedding_store = self.setup_vector_store(config_args=config.embeddings.vector_store)
+
+        for data_path in data_paths:
+            # check from the config for the output storage type and then read the data from the storage.
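+            # Artifacts read per data path: final nodes, community reports and
+            # text units, plus covariates (unless optimized search is enabled),
+            # relationships and entities.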
+
+            # GraphDB: we may need to change this to read the nodes data from Graph DB instead.
+            final_nodes = read_parquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")
+            # KustoDB: final entities, final nodes, and final reports should be merged and inserted into Kusto.
+            final_community_reports = read_parquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")
+            # LanceDB search needs this for embedding mapping; the entities already carry embeddings, so those should be used. KustoDB already has this covered.
+            final_text_units = read_parquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")
+
+            if not optimized_search:
+                final_covariates = read_parquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")
+
+            final_relationships = read_parquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")
+            final_entities = read_parquet_file(input_storage_client, data_path + "/create_final_entities.parquet")
+
+        vector_store_args = (
+            config.embeddings.vector_store if config.embeddings.vector_store else {}
+        )
+
+        reporter.info(f"Vector Store Args: {vector_store_args}")
+
+        if "type" not in vector_store_args:
+            msg = "vector_store.type can't be empty"
+            raise ValueError(msg)
+
+        vector_store_type = vector_store_args.get("type")
+
+        if vector_store_type != VectorStoreType.Kusto:
+            msg = "Context switching is only supported for vector_store.type=kusto"
+            raise ValueError(msg)
+
+        # KustoDB: read the final nodes and entities data and merge them.
+        entities = read_indexer_entities(final_nodes, final_entities, community_level)
+        reports = read_indexer_reports(final_community_reports, final_nodes, community_level)
+        text_units = read_indexer_text_units(final_text_units)
+
+        description_embedding_store.load_entities(entities)
+        if self.use_kusto_community_reports:
+            msg = "Community reports not supported for kusto."
+            raise ValueError(msg)
+            #description_embedding_store.load_reports(reports)
+
+        description_embedding_store.load_text_units(text_units)
+
+        if config.graphdb.enabled:
+            graph_db_client.write_vertices(final_entities)
+            graph_db_client.write_edges(final_relationships)
+            graph_db_client._client.close()
+
+    def deactivate(self):
+        """Deactivate the context."""
+        config = self._read_config_parameters(self.root_dir or "./", self.config_dir)
+        config_args = config.embeddings.vector_store or {}
+        description_embedding_store = self.get_embedding_store(config_args)
+        description_embedding_store.connect(**config_args)
+        description_embedding_store.unload_entities()
+
+        if config.graphdb.enabled:
+            g_client = GraphDBClient(config.graphdb, self.context_id)
+            g_client.remove_graph()
\ No newline at end of file
diff --git a/func-app/graphrag/index/create_pipeline_config.py b/func-app/graphrag/index/create_pipeline_config.py
new file mode 100644
index 0000000000..7cf91ec308
--- /dev/null
+++ b/func-app/graphrag/index/create_pipeline_config.py
@@ -0,0 +1,595 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Default configuration methods definition.""" + +import json +import logging +from pathlib import Path +from .emit.types import TableEmitterType + +from graphrag.config.enums import ( + CacheType, + InputFileType, + ReportingType, + StorageType, + TextEmbeddingTarget, +) +from graphrag.config.models import ( + GraphRagConfig, + TextEmbeddingConfig, +) +from graphrag.index.config.cache import ( + PipelineBlobCacheConfig, + PipelineCacheConfigTypes, + PipelineFileCacheConfig, + PipelineMemoryCacheConfig, + PipelineNoneCacheConfig, +) +from graphrag.index.config.input import ( + PipelineCSVInputConfig, + PipelineInputConfigTypes, + PipelineTextInputConfig, +) +from graphrag.index.config.pipeline import ( + PipelineConfig, +) +from graphrag.index.config.reporting import ( + PipelineBlobReportingConfig, + PipelineConsoleReportingConfig, + PipelineFileReportingConfig, + PipelineReportingConfigTypes, +) +from graphrag.common.config.storage import ( + PipelineBlobStorageConfig, + PipelineFileStorageConfig, + PipelineMemoryStorageConfig, + PipelineStorageConfigTypes, +) +from graphrag.index.config.workflow import ( + PipelineWorkflowReference, +) +from graphrag.index.workflows.default_workflows import ( + create_base_documents, + create_base_entity_graph, + create_base_extracted_entities, + create_base_text_units, + create_final_communities, + create_final_community_reports, + create_final_covariates, + create_final_documents, + create_final_entities, + create_final_nodes, + create_final_relationships, + create_final_text_units, + create_summarized_entities, + join_text_units_to_covariate_ids, + join_text_units_to_entity_ids, + join_text_units_to_relationship_ids, +) + +log = logging.getLogger(__name__) + + +entity_name_embedding = "entity.name" +entity_description_embedding = "entity.description" +relationship_description_embedding = "relationship.description" +document_raw_content_embedding = "document.raw_content" +community_title_embedding = "community.title" +community_summary_embedding = "community.summary" +community_full_content_embedding = "community.full_content" +text_unit_text_embedding = "text_unit.text" + +all_embeddings: set[str] = { + entity_name_embedding, + entity_description_embedding, + relationship_description_embedding, + document_raw_content_embedding, + community_title_embedding, + community_summary_embedding, + community_full_content_embedding, + text_unit_text_embedding, +} +required_embeddings: set[str] = {entity_description_embedding} + + +builtin_document_attributes: set[str] = { + "id", + "source", + "text", + "title", + "timestamp", + "year", + "month", + "day", + "hour", + "minute", + "second", +} + + +def create_pipeline_config(settings: GraphRagConfig, verbose=False) -> PipelineConfig: + """Get the default config for the pipeline.""" + # relative to the root_dir + if verbose: + _log_llm_settings(settings) + + skip_workflows = _determine_skip_workflows(settings) + embedded_fields = _get_embedded_fields(settings) + covariates_enabled = ( + settings.claim_extraction.enabled + and create_final_covariates not in skip_workflows + ) + + result = PipelineConfig( + root_dir=settings.root_dir, + input=_get_pipeline_input_config(settings), + reporting=_get_reporting_config(settings), + storage=_get_storage_config(settings), + cache=_get_cache_config(settings), + workflows=[ + *_document_workflows(settings, embedded_fields), + *_text_unit_workflows(settings, covariates_enabled, embedded_fields), + *_graph_workflows(settings, 
embedded_fields), + *_community_workflows(settings, covariates_enabled, embedded_fields), + *(_covariate_workflows(settings) if covariates_enabled else []), + ], + graphdb_params=settings.graphdb + ) + + # Remove any workflows that were specified to be skipped + log.info("skipping workflows %s", ",".join(skip_workflows)) + result.workflows = [w for w in result.workflows if w.name not in skip_workflows] + return result + + +def _get_embedded_fields(settings: GraphRagConfig) -> set[str]: + match settings.embeddings.target: + case TextEmbeddingTarget.all: + return all_embeddings - {*settings.embeddings.skip} + case TextEmbeddingTarget.required: + return required_embeddings + case _: + msg = f"Unknown embeddings target: {settings.embeddings.target}" + raise ValueError(msg) + + +def _determine_skip_workflows(settings: GraphRagConfig) -> list[str]: + skip_workflows = settings.skip_workflows + if ( + create_final_covariates in skip_workflows + and join_text_units_to_covariate_ids not in skip_workflows + ): + skip_workflows.append(join_text_units_to_covariate_ids) + return skip_workflows + + +def _log_llm_settings(settings: GraphRagConfig) -> None: + log.info( + "Using LLM Config %s", + json.dumps( + {**settings.entity_extraction.llm.model_dump(), "api_key": "*****"}, + indent=4, + ), + ) + log.info( + "Using Embeddings Config %s", + json.dumps( + {**settings.embeddings.llm.model_dump(), "api_key": "*****"}, indent=4 + ), + ) + + +def _document_workflows( + settings: GraphRagConfig, embedded_fields: set[str] +) -> list[PipelineWorkflowReference]: + skip_document_raw_content_embedding = ( + document_raw_content_embedding not in embedded_fields + ) + return [ + PipelineWorkflowReference( + name=create_base_documents, + config={ + "document_attribute_columns": list( + {*(settings.input.document_attribute_columns)} + - builtin_document_attributes + ) + }, + ), + PipelineWorkflowReference( + name=create_final_documents, + config={ + "document_raw_content_embed": _get_embedding_settings( + settings.embeddings, + "document_raw_content", + { + "title_column": "raw_content", + "collection_name": "final_documents_raw_content_embedding", + }, + ), + "skip_raw_content_embedding": skip_document_raw_content_embedding, + }, + ), + ] + + +def _text_unit_workflows( + settings: GraphRagConfig, + covariates_enabled: bool, + embedded_fields: set[str], +) -> list[PipelineWorkflowReference]: + skip_text_unit_embedding = text_unit_text_embedding not in embedded_fields + return [ + PipelineWorkflowReference( + name=create_base_text_units, + config={ + "chunk_by": settings.chunks.group_by_columns, + "text_chunk": { + "strategy": settings.chunks.resolved_strategy( + settings.encoding_model + ) + }, + }, + ), + PipelineWorkflowReference( + name=join_text_units_to_entity_ids, + ), + PipelineWorkflowReference( + name=join_text_units_to_relationship_ids, + ), + *( + [ + PipelineWorkflowReference( + name=join_text_units_to_covariate_ids, + ) + ] + if covariates_enabled + else [] + ), + PipelineWorkflowReference( + name=create_final_text_units, + config={ + "text_unit_text_embed": _get_embedding_settings( + settings.embeddings, + "text_unit_text", + {"title_column": "text", "collection_name": "text_units_embedding"}, + ), + "covariates_enabled": covariates_enabled, + "skip_text_unit_embedding": skip_text_unit_embedding, + }, + ), + ] + + +def _get_embedding_settings( + settings: TextEmbeddingConfig, + embedding_name: str, + vector_store_params: dict | None = None, +) -> dict: + vector_store_settings = settings.vector_store + if 
vector_store_settings is None:
+        return {"strategy": settings.resolved_strategy()}
+    #
+    # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
+    # settings.vector_store.base contains connection information, or may be undefined
+    # settings.vector_store.<embedding_name> contains the specific settings for this embedding
+    #
+    strategy = settings.resolved_strategy()  # get the default strategy
+    strategy.update({
+        "vector_store": {**vector_store_settings, **(vector_store_params or {})}
+    })  # update the default strategy with the vector store settings
+    # This ensures the vector store config is part of the strategy and not the global config
+    return {
+        "strategy": strategy,
+        "embedding_name": embedding_name,
+    }
+
+
+def _graph_workflows(
+    settings: GraphRagConfig, embedded_fields: set[str]
+) -> list[PipelineWorkflowReference]:
+    skip_entity_name_embedding = entity_name_embedding not in embedded_fields
+    skip_entity_description_embedding = (
+        entity_description_embedding not in embedded_fields
+    )
+    skip_relationship_description_embedding = (
+        relationship_description_embedding not in embedded_fields
+    )
+    return [
+        PipelineWorkflowReference(
+            name=create_base_extracted_entities,
+            config={
+                "graphml_snapshot": settings.snapshots.graphml,
+                "raw_entity_snapshot": settings.snapshots.raw_entities,
+                "entity_extract": {
+                    **settings.entity_extraction.parallelization.model_dump(),
+                    "async_mode": settings.entity_extraction.async_mode,
+                    "strategy": settings.entity_extraction.resolved_strategy(
+                        settings.root_dir, settings.encoding_model
+                    ),
+                    "entity_types": settings.entity_extraction.entity_types,
+                },
+            },
+        ),
+        PipelineWorkflowReference(
+            name=create_summarized_entities,
+            config={
+                "graphml_snapshot": settings.snapshots.graphml,
+                "summarize_descriptions": {
+                    **settings.summarize_descriptions.parallelization.model_dump(),
+                    "async_mode": settings.summarize_descriptions.async_mode,
+                    "strategy": settings.summarize_descriptions.resolved_strategy(
+                        settings.root_dir,
+                    ),
+                },
+            },
+        ),
+        PipelineWorkflowReference(
+            name=create_base_entity_graph,
+            config={
+                "graphml_snapshot": settings.snapshots.graphml,
+                "embed_graph_enabled": settings.embed_graph.enabled,
+                "cluster_graph": {
+                    "strategy": settings.cluster_graph.resolved_strategy()
+                },
+                "embed_graph": {"strategy": settings.embed_graph.resolved_strategy()},
+            },
+        ),
+        PipelineWorkflowReference(
+            name=create_final_entities,
+            config={
+                "entity_name_embed": _get_embedding_settings(
+                    settings.embeddings,
+                    "entity_name",
+                    {
+                        "title_column": "name",
+                        "collection_name": "entity_name_embeddings",
+                    },
+                ),
+                "entity_name_description_embed": _get_embedding_settings(
+                    settings.embeddings,
+                    "entity_name_description",
+                    {
+                        "title_column": "description",
+                        "collection_name": "entity_description_embeddings",
+                        "vector_name": "vector",
+                        "reports_name": "reports",
+                    },
+                ),
+                "skip_name_embedding": skip_entity_name_embedding,
+                "skip_description_embedding": skip_entity_description_embedding,
+                "emitter_type": TableEmitterType.Graphdb,
+            },
+        ),
+        PipelineWorkflowReference(
+            name=create_final_relationships,
+            config={
+                "relationship_description_embed": _get_embedding_settings(
+                    settings.embeddings,
+                    "relationship_description",
+                    {
+                        "title_column": "description",
+                        "collection_name": "relationships_description_embeddings",
+                    },
+                ),
+                "skip_description_embedding": skip_relationship_description_embedding,
+                "emitter_type": TableEmitterType.Graphdb,
+            },
+        ),
+        
PipelineWorkflowReference( + name=create_final_nodes, + config={ + "layout_graph_enabled": settings.umap.enabled, + "snapshot_top_level_nodes": settings.snapshots.top_level_nodes, + }, + ), + ] + + +def _community_workflows( + settings: GraphRagConfig, covariates_enabled: bool, embedded_fields: set[str] +) -> list[PipelineWorkflowReference]: + skip_community_title_embedding = community_title_embedding not in embedded_fields + skip_community_summary_embedding = ( + community_summary_embedding not in embedded_fields + ) + skip_community_full_content_embedding = ( + community_full_content_embedding not in embedded_fields + ) + return [ + PipelineWorkflowReference(name=create_final_communities), + PipelineWorkflowReference( + name=create_final_community_reports, + config={ + "covariates_enabled": covariates_enabled, + "skip_title_embedding": skip_community_title_embedding, + "skip_summary_embedding": skip_community_summary_embedding, + "skip_full_content_embedding": skip_community_full_content_embedding, + "create_community_reports": { + **settings.community_reports.parallelization.model_dump(), + "async_mode": settings.community_reports.async_mode, + "strategy": settings.community_reports.resolved_strategy( + settings.root_dir + ), + }, + "community_report_full_content_embed": _get_embedding_settings( + settings.embeddings, + "community_report_full_content", + { + "title_column": "full_content", + "collection_name": "final_community_reports_full_content_embedding", + }, + ), + "community_report_summary_embed": _get_embedding_settings( + settings.embeddings, + "community_report_summary", + { + "title_column": "summary", + "collection_name": "final_community_reports_summary_embedding", + }, + ), + "community_report_title_embed": _get_embedding_settings( + settings.embeddings, + "community_report_title", + {"title_column": "title"}, + ), + }, + ), + ] + + +def _covariate_workflows( + settings: GraphRagConfig, +) -> list[PipelineWorkflowReference]: + return [ + PipelineWorkflowReference( + name=create_final_covariates, + config={ + "claim_extract": { + **settings.claim_extraction.parallelization.model_dump(), + "strategy": settings.claim_extraction.resolved_strategy( + settings.root_dir, settings.encoding_model + ), + }, + }, + ) + ] + + +def _get_pipeline_input_config( + settings: GraphRagConfig, +) -> PipelineInputConfigTypes: + file_type = settings.input.file_type + match file_type: + case InputFileType.csv: + return PipelineCSVInputConfig( + base_dir=settings.input.base_dir, + file_pattern=settings.input.file_pattern, + encoding=settings.input.encoding, + source_column=settings.input.source_column, + timestamp_column=settings.input.timestamp_column, + timestamp_format=settings.input.timestamp_format, + text_column=settings.input.text_column, + title_column=settings.input.title_column, + type=settings.input.type, + connection_string=settings.input.connection_string, + storage_account_blob_url=settings.input.storage_account_blob_url, + container_name=settings.input.container_name, + ) + case InputFileType.text: + return PipelineTextInputConfig( + base_dir=settings.input.base_dir, + file_pattern=settings.input.file_pattern, + encoding=settings.input.encoding, + type=settings.input.type, + connection_string=settings.input.connection_string, + storage_account_blob_url=settings.input.storage_account_blob_url, + container_name=settings.input.container_name, + ) + case _: + msg = f"Unknown input type: {file_type}" + raise ValueError(msg) + + +def _get_reporting_config( + settings: GraphRagConfig, +) 
-> PipelineReportingConfigTypes: + """Get the reporting config from the settings.""" + match settings.reporting.type: + case ReportingType.file: + # relative to the root_dir + return PipelineFileReportingConfig(base_dir=settings.reporting.base_dir) + case ReportingType.blob: + connection_string = settings.reporting.connection_string + storage_account_blob_url = settings.reporting.storage_account_blob_url + container_name = settings.reporting.container_name + if container_name is None: + msg = "Container name must be provided for blob reporting." + raise ValueError(msg) + if connection_string is None and storage_account_blob_url is None: + msg = "Connection string or storage account blob url must be provided for blob reporting." + raise ValueError(msg) + return PipelineBlobReportingConfig( + connection_string=connection_string, + container_name=container_name, + base_dir=settings.reporting.base_dir, + storage_account_blob_url=storage_account_blob_url, + ) + case ReportingType.console: + return PipelineConsoleReportingConfig() + case _: + # relative to the root_dir + return PipelineFileReportingConfig(base_dir=settings.reporting.base_dir) + + +def _get_storage_config( + settings: GraphRagConfig, +) -> PipelineStorageConfigTypes: + """Get the storage type from the settings.""" + root_dir = settings.root_dir + match settings.storage.type: + case StorageType.memory: + return PipelineMemoryStorageConfig() + case StorageType.file: + # relative to the root_dir + base_dir = settings.storage.base_dir + if base_dir is None: + msg = "Base directory must be provided for file storage." + raise ValueError(msg) + return PipelineFileStorageConfig(base_dir=str(Path(root_dir) / base_dir)) + case StorageType.blob: + connection_string = settings.storage.connection_string + storage_account_blob_url = settings.storage.storage_account_blob_url + container_name = settings.storage.container_name + if container_name is None: + msg = "Container name must be provided for blob storage." + raise ValueError(msg) + if connection_string is None and storage_account_blob_url is None: + msg = "Connection string or storage account blob url must be provided for blob storage." + raise ValueError(msg) + return PipelineBlobStorageConfig( + connection_string=connection_string, + container_name=container_name, + base_dir=settings.storage.base_dir, + storage_account_blob_url=storage_account_blob_url, + ) + case _: + # relative to the root_dir + base_dir = settings.storage.base_dir + if base_dir is None: + msg = "Base directory must be provided for file storage." + raise ValueError(msg) + return PipelineFileStorageConfig(base_dir=str(Path(root_dir) / base_dir)) + + +def _get_cache_config( + settings: GraphRagConfig, +) -> PipelineCacheConfigTypes: + """Get the cache type from the settings.""" + match settings.cache.type: + case CacheType.memory: + return PipelineMemoryCacheConfig() + case CacheType.file: + # relative to root dir + return PipelineFileCacheConfig(base_dir=settings.cache.base_dir) + case CacheType.none: + return PipelineNoneCacheConfig() + case CacheType.blob: + connection_string = settings.cache.connection_string + storage_account_blob_url = settings.cache.storage_account_blob_url + container_name = settings.cache.container_name + if container_name is None: + msg = "Container name must be provided for blob cache." + raise ValueError(msg) + if connection_string is None and storage_account_blob_url is None: + msg = "Connection string or storage account blob url must be provided for blob cache." 
+ raise ValueError(msg) + return PipelineBlobCacheConfig( + connection_string=connection_string, + container_name=container_name, + base_dir=settings.cache.base_dir, + storage_account_blob_url=storage_account_blob_url, + ) + case _: + # relative to root dir + return PipelineFileCacheConfig(base_dir="./cache") diff --git a/func-app/graphrag/index/emit/__init__.py b/func-app/graphrag/index/emit/__init__.py new file mode 100644 index 0000000000..354989e338 --- /dev/null +++ b/func-app/graphrag/index/emit/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Definitions for emitting pipeline artifacts to storage.""" + +from .csv_table_emitter import CSVTableEmitter +from .factories import create_table_emitter, create_table_emitters +from .json_table_emitter import JsonTableEmitter +from .parquet_table_emitter import ParquetTableEmitter +from .table_emitter import TableEmitter +from .types import TableEmitterType + +__all__ = [ + "CSVTableEmitter", + "JsonTableEmitter", + "ParquetTableEmitter", + "TableEmitter", + "TableEmitterType", + "create_table_emitter", + "create_table_emitters", +] diff --git a/func-app/graphrag/index/emit/csv_table_emitter.py b/func-app/graphrag/index/emit/csv_table_emitter.py new file mode 100644 index 0000000000..0c208d1264 --- /dev/null +++ b/func-app/graphrag/index/emit/csv_table_emitter.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""CSVTableEmitter module.""" + +import logging + +import pandas as pd + +from graphrag.common.storage import PipelineStorage + +from .table_emitter import TableEmitter + +log = logging.getLogger(__name__) + + +class CSVTableEmitter(TableEmitter): + """CSVTableEmitter class.""" + + _storage: PipelineStorage + + def __init__(self, storage: PipelineStorage): + """Create a new CSV Table Emitter.""" + self._storage = storage + + async def emit(self, name: str, data: pd.DataFrame) -> None: + """Emit a dataframe to storage.""" + filename = f"{name}.csv" + log.info("emitting CSV table %s", filename) + await self._storage.set( + filename, + data.to_csv(), + ) diff --git a/func-app/graphrag/index/emit/factories.py b/func-app/graphrag/index/emit/factories.py new file mode 100644 index 0000000000..1c4e218785 --- /dev/null +++ b/func-app/graphrag/index/emit/factories.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Table Emitter Factories."""
+
+from graphrag.config.models.graphdb_config import GraphDBConfig
+from graphrag.common.storage import PipelineStorage
+from graphrag.index.typing import ErrorHandlerFn
+
+from .csv_table_emitter import CSVTableEmitter
+from .json_table_emitter import JsonTableEmitter
+from .parquet_table_emitter import ParquetTableEmitter
+from .graph_db_emitter import GraphDBEmitter
+from .table_emitter import TableEmitter
+from .types import TableEmitterType
+
+def create_table_emitter(
+    emitter_type: TableEmitterType,
+    storage: PipelineStorage,
+    on_error: ErrorHandlerFn,
+    graphdb_params: GraphDBConfig | None = None,
+    context_id: str | None = None,
+) -> TableEmitter:
+    """Create a table emitter based on the specified type."""
+    match emitter_type:
+        case TableEmitterType.Json:
+            return JsonTableEmitter(storage)
+        case TableEmitterType.Parquet:
+            return ParquetTableEmitter(storage, on_error)
+        case TableEmitterType.CSV:
+            return CSVTableEmitter(storage)
+        case TableEmitterType.Graphdb:
+            return GraphDBEmitter(graphdb_params, context_id)
+        case _:
+            msg = f"Unsupported table emitter type: {emitter_type}"
+            raise ValueError(msg)
+
+
+def create_table_emitters(
+    emitter_types: list[TableEmitterType],
+    storage: PipelineStorage,
+    on_error: ErrorHandlerFn,
+    graphdb_params: GraphDBConfig | None = None,
+    context_id: str | None = None,
+) -> list[TableEmitter]:
+    """Create a list of table emitters based on the specified types."""
+    return [
+        create_table_emitter(emitter_type, storage, on_error, graphdb_params, context_id)
+        for emitter_type in emitter_types
+    ]
diff --git a/func-app/graphrag/index/emit/graph_db_emitter.py b/func-app/graphrag/index/emit/graph_db_emitter.py
new file mode 100644
index 0000000000..d8018ee678
--- /dev/null
+++ b/func-app/graphrag/index/emit/graph_db_emitter.py
@@ -0,0 +1,24 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""GraphDBEmitter module."""
+
+import pandas as pd
+from graphrag.common.graph_db_client import GraphDBClient
+from .table_emitter import TableEmitter
+from graphrag.config.models.graphdb_config import GraphDBConfig
+
+class GraphDBEmitter(TableEmitter):
+    """Graph DB Emitter."""
+
+    def __init__(self, graph_db_params: GraphDBConfig | None, context_id: str | None):
+        """Create a new GraphDB Emitter."""
+        self.graph_db_client = GraphDBClient(graph_db_params, context_id)
+        self.allowed_workflows = ['create_final_entities', 'create_final_relationships']
+
+    async def emit(self, name: str, data: pd.DataFrame) -> None:
+        """Emit entity or relationship tables to the graph database."""
+        if name not in self.allowed_workflows:
+            return
+        if name == 'create_final_entities':
+            self.graph_db_client.write_vertices(data)
+        if name == 'create_final_relationships':
+            self.graph_db_client.write_edges(data)
\ No newline at end of file
diff --git a/func-app/graphrag/index/emit/json_table_emitter.py b/func-app/graphrag/index/emit/json_table_emitter.py
new file mode 100644
index 0000000000..39f936b781
--- /dev/null
+++ b/func-app/graphrag/index/emit/json_table_emitter.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""JsonTableEmitter module.""" + +import logging + +import pandas as pd + +from graphrag.common.storage import PipelineStorage + +from .table_emitter import TableEmitter + +log = logging.getLogger(__name__) + + +class JsonTableEmitter(TableEmitter): + """JsonTableEmitter class.""" + + _storage: PipelineStorage + + def __init__(self, storage: PipelineStorage): + """Create a new Json Table Emitter.""" + self._storage = storage + + async def emit(self, name: str, data: pd.DataFrame) -> None: + """Emit a dataframe to storage.""" + filename = f"{name}.json" + + log.info("emitting JSON table %s", filename) + await self._storage.set( + filename, + data.to_json(orient="records", lines=True, force_ascii=False), + ) diff --git a/func-app/graphrag/index/emit/parquet_table_emitter.py b/func-app/graphrag/index/emit/parquet_table_emitter.py new file mode 100644 index 0000000000..aa6dd38f96 --- /dev/null +++ b/func-app/graphrag/index/emit/parquet_table_emitter.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""ParquetTableEmitter module.""" + +import logging +import traceback + +import pandas as pd +from pyarrow.lib import ArrowInvalid, ArrowTypeError + +from graphrag.common.storage import PipelineStorage +from graphrag.index.typing import ErrorHandlerFn + +from .table_emitter import TableEmitter + +log = logging.getLogger(__name__) + + +class ParquetTableEmitter(TableEmitter): + """ParquetTableEmitter class.""" + + _storage: PipelineStorage + _on_error: ErrorHandlerFn + + def __init__( + self, + storage: PipelineStorage, + on_error: ErrorHandlerFn, + ): + """Create a new Parquet Table Emitter.""" + self._storage = storage + self._on_error = on_error + + async def emit(self, name: str, data: pd.DataFrame) -> None: + """Emit a dataframe to storage.""" + filename = f"{name}.parquet" + log.info("emitting parquet table %s", filename) + try: + await self._storage.set(filename, data.to_parquet()) + except ArrowTypeError as e: + log.exception("Error while emitting parquet table") + self._on_error( + e, + traceback.format_exc(), + None, + ) + except ArrowInvalid as e: + log.exception("Error while emitting parquet table") + self._on_error( + e, + traceback.format_exc(), + None, + ) diff --git a/func-app/graphrag/index/emit/table_emitter.py b/func-app/graphrag/index/emit/table_emitter.py new file mode 100644 index 0000000000..2161eeb523 --- /dev/null +++ b/func-app/graphrag/index/emit/table_emitter.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""TableEmitter protocol for emitting tables to a destination.""" + +from typing import Protocol + +import pandas as pd + + +class TableEmitter(Protocol): + """TableEmitter protocol for emitting tables to a destination.""" + + async def emit(self, name: str, data: pd.DataFrame) -> None: + """Emit a dataframe to storage.""" diff --git a/func-app/graphrag/index/emit/types.py b/func-app/graphrag/index/emit/types.py new file mode 100644 index 0000000000..0b0ff88541 --- /dev/null +++ b/func-app/graphrag/index/emit/types.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Table Emitter Types.""" + +from enum import Enum + + +class TableEmitterType(str, Enum): + """Table Emitter Types.""" + + Json = "json" + Parquet = "parquet" + CSV = "csv" + Graphdb = "graphdb" diff --git a/func-app/graphrag/index/errors.py b/func-app/graphrag/index/errors.py new file mode 100644 index 0000000000..430cf27d0f --- /dev/null +++ b/func-app/graphrag/index/errors.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""GraphRAG indexing error types.""" + + +class NoWorkflowsDefinedError(ValueError): + """Exception for no workflows defined.""" + + def __init__(self): + super().__init__("No workflows defined.") + + +class UndefinedWorkflowError(ValueError): + """Exception for invalid verb input.""" + + def __init__(self): + super().__init__("Workflow name is undefined.") + + +class UnknownWorkflowError(ValueError): + """Exception for invalid verb input.""" + + def __init__(self, name: str): + super().__init__(f"Unknown workflow: {name}") diff --git a/func-app/graphrag/index/graph/__init__.py b/func-app/graphrag/index/graph/__init__.py new file mode 100644 index 0000000000..cb26e59595 --- /dev/null +++ b/func-app/graphrag/index/graph/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph package root.""" diff --git a/func-app/graphrag/index/graph/embedding/__init__.py b/func-app/graphrag/index/graph/embedding/__init__.py new file mode 100644 index 0000000000..0ea2d085f1 --- /dev/null +++ b/func-app/graphrag/index/graph/embedding/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph embedding package root.""" + +from .embedding import NodeEmbeddings, embed_nod2vec + +__all__ = ["NodeEmbeddings", "embed_nod2vec"] diff --git a/func-app/graphrag/index/graph/embedding/embedding.py b/func-app/graphrag/index/graph/embedding/embedding.py new file mode 100644 index 0000000000..267a190f91 --- /dev/null +++ b/func-app/graphrag/index/graph/embedding/embedding.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Utilities to generate graph embeddings.""" + +from dataclasses import dataclass + +import graspologic as gc +import networkx as nx +import numpy as np + + +@dataclass +class NodeEmbeddings: + """Node embeddings class definition.""" + + nodes: list[str] + embeddings: np.ndarray + + +def embed_nod2vec( + graph: nx.Graph | nx.DiGraph, + dimensions: int = 1536, + num_walks: int = 10, + walk_length: int = 40, + window_size: int = 2, + iterations: int = 3, + random_seed: int = 86, +) -> NodeEmbeddings: + """Generate node embeddings using Node2Vec.""" + # generate embedding + lcc_tensors = gc.embed.node2vec_embed( # type: ignore + graph=graph, + dimensions=dimensions, + window_size=window_size, + iterations=iterations, + num_walks=num_walks, + walk_length=walk_length, + random_seed=random_seed, + ) + return NodeEmbeddings(embeddings=lcc_tensors[0], nodes=lcc_tensors[1]) diff --git a/func-app/graphrag/index/graph/extractors/__init__.py b/func-app/graphrag/index/graph/extractors/__init__.py new file mode 100644 index 0000000000..9168d5e207 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""The Indexing Engine graph extractors package root.""" + +from .claims import CLAIM_EXTRACTION_PROMPT, ClaimExtractor +from .community_reports import ( + COMMUNITY_REPORT_PROMPT, + CommunityReportsExtractor, +) +from .graph import GraphExtractionResult, GraphExtractor + +__all__ = [ + "CLAIM_EXTRACTION_PROMPT", + "COMMUNITY_REPORT_PROMPT", + "ClaimExtractor", + "CommunityReportsExtractor", + "GraphExtractionResult", + "GraphExtractor", +] diff --git a/func-app/graphrag/index/graph/extractors/claims/__init__.py b/func-app/graphrag/index/graph/extractors/claims/__init__.py new file mode 100644 index 0000000000..3977c8ff83 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/claims/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph extractors claims package root.""" + +from .claim_extractor import ClaimExtractor +from .prompts import CLAIM_EXTRACTION_PROMPT + +__all__ = ["CLAIM_EXTRACTION_PROMPT", "ClaimExtractor"] diff --git a/func-app/graphrag/index/graph/extractors/claims/claim_extractor.py b/func-app/graphrag/index/graph/extractors/claims/claim_extractor.py new file mode 100644 index 0000000000..c7e76d5067 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/claims/claim_extractor.py @@ -0,0 +1,248 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'ClaimExtractorResult' and 'ClaimExtractor' models.""" + +import logging +import traceback +from dataclasses import dataclass +from typing import Any + +import tiktoken + +import graphrag.config.defaults as defs +from graphrag.index.typing import ErrorHandlerFn +from graphrag.llm import CompletionLLM + +from .prompts import ( + CLAIM_EXTRACTION_PROMPT, + CONTINUE_PROMPT, + LOOP_PROMPT, +) + +DEFAULT_TUPLE_DELIMITER = "<|>" +DEFAULT_RECORD_DELIMITER = "##" +DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>" +log = logging.getLogger(__name__) + + +@dataclass +class ClaimExtractorResult: + """Claim extractor result class definition.""" + + output: list[dict] + source_docs: dict[str, Any] + + +class ClaimExtractor: + """Claim extractor class definition.""" + + _llm: CompletionLLM + _extraction_prompt: str + _summary_prompt: str + _output_formatter_prompt: str + _input_text_key: str + _input_entity_spec_key: str + _input_claim_description_key: str + _tuple_delimiter_key: str + _record_delimiter_key: str + _completion_delimiter_key: str + _max_gleanings: int + _on_error: ErrorHandlerFn + + def __init__( + self, + llm_invoker: CompletionLLM, + extraction_prompt: str | None = None, + input_text_key: str | None = None, + input_entity_spec_key: str | None = None, + input_claim_description_key: str | None = None, + input_resolved_entities_key: str | None = None, + tuple_delimiter_key: str | None = None, + record_delimiter_key: str | None = None, + completion_delimiter_key: str | None = None, + encoding_model: str | None = None, + max_gleanings: int | None = None, + on_error: ErrorHandlerFn | None = None, + ): + """Init method definition.""" + self._llm = llm_invoker + self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT + self._input_text_key = input_text_key or "input_text" + self._input_entity_spec_key = input_entity_spec_key or "entity_specs" + self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter" + self._record_delimiter_key = record_delimiter_key or "record_delimiter" + self._completion_delimiter_key = ( + 
completion_delimiter_key or "completion_delimiter" + ) + self._input_claim_description_key = ( + input_claim_description_key or "claim_description" + ) + self._input_resolved_entities_key = ( + input_resolved_entities_key or "resolved_entities" + ) + self._max_gleanings = ( + max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS + ) + self._on_error = on_error or (lambda _e, _s, _d: None) + + # Construct the looping arguments + encoding = tiktoken.get_encoding(encoding_model or "cl100k_base") + yes = encoding.encode("YES") + no = encoding.encode("NO") + self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1} + + async def __call__( + self, inputs: dict[str, Any], prompt_variables: dict | None = None + ) -> ClaimExtractorResult: + """Call method definition.""" + if prompt_variables is None: + prompt_variables = {} + texts = inputs[self._input_text_key] + entity_spec = str(inputs[self._input_entity_spec_key]) + claim_description = inputs[self._input_claim_description_key] + resolved_entities = inputs.get(self._input_resolved_entities_key, {}) + source_doc_map = {} + + prompt_args = { + self._input_entity_spec_key: entity_spec, + self._input_claim_description_key: claim_description, + self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key) + or DEFAULT_TUPLE_DELIMITER, + self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key) + or DEFAULT_RECORD_DELIMITER, + self._completion_delimiter_key: prompt_variables.get( + self._completion_delimiter_key + ) + or DEFAULT_COMPLETION_DELIMITER, + } + + all_claims: list[dict] = [] + for doc_index, text in enumerate(texts): + document_id = f"d{doc_index}" + try: + claims = await self._process_document(prompt_args, text, doc_index) + all_claims += [ + self._clean_claim(c, document_id, resolved_entities) for c in claims + ] + source_doc_map[document_id] = text + except Exception as e: + log.exception("error extracting claim") + self._on_error( + e, + traceback.format_exc(), + {"doc_index": doc_index, "text": text}, + ) + continue + + return ClaimExtractorResult( + output=all_claims, + source_docs=source_doc_map, + ) + + def _clean_claim( + self, claim: dict, document_id: str, resolved_entities: dict + ) -> dict: + # clean the parsed claims to remove any claims with status = False + obj = claim.get("object_id", claim.get("object")) + subject = claim.get("subject_id", claim.get("subject")) + + # If subject or object in resolved entities, then replace with resolved entity + obj = resolved_entities.get(obj, obj) + subject = resolved_entities.get(subject, subject) + claim["object_id"] = obj + claim["subject_id"] = subject + claim["doc_id"] = document_id + return claim + + async def _process_document( + self, prompt_args: dict, doc, doc_index: int + ) -> list[dict]: + record_delimiter = prompt_args.get( + self._record_delimiter_key, DEFAULT_RECORD_DELIMITER + ) + completion_delimiter = prompt_args.get( + self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER + ) + + response = await self._llm( + self._extraction_prompt, + variables={ + self._input_text_key: doc, + **prompt_args, + }, + ) + results = response.output or "" + claims = results.strip().removesuffix(completion_delimiter) + + # Repeat to ensure we maximize entity count + for i in range(self._max_gleanings): + response = await self._llm( + CONTINUE_PROMPT, + name=f"extract-continuation-{i}", + history=response.history, + ) + extension = response.output or "" + claims += record_delimiter + extension.strip().removesuffix( + 
completion_delimiter
+            )
+
+            # If this isn't the last loop, check to see if we should continue
+            if i >= self._max_gleanings - 1:
+                break
+
+            response = await self._llm(
+                LOOP_PROMPT,
+                name=f"extract-loopcheck-{i}",
+                history=response.history,
+                model_parameters=self._loop_args,
+            )
+            if response.output != "YES":
+                break
+
+        # Parse the accumulated claim text (initial response plus all gleanings).
+        result = self._parse_claim_tuples(claims, prompt_args)
+        for r in result:
+            r["doc_id"] = f"{doc_index}"
+        return result
+
+    def _parse_claim_tuples(
+        self, claims: str, prompt_variables: dict
+    ) -> list[dict[str, Any]]:
+        """Parse claim tuples."""
+        record_delimiter = prompt_variables.get(
+            self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
+        )
+        completion_delimiter = prompt_variables.get(
+            self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
+        )
+        tuple_delimiter = prompt_variables.get(
+            self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
+        )
+
+        def pull_field(index: int, fields: list[str]) -> str | None:
+            return fields[index].strip() if len(fields) > index else None
+
+        result: list[dict[str, Any]] = []
+        claims_values = (
+            claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
+        )
+        for claim in claims_values:
+            claim = claim.strip().removeprefix("(").removesuffix(")")
+
+            # Ignore the completion delimiter
+            if claim == completion_delimiter:
+                continue
+
+            claim_fields = claim.split(tuple_delimiter)
+            result.append({
+                "subject_id": pull_field(0, claim_fields),
+                "object_id": pull_field(1, claim_fields),
+                "type": pull_field(2, claim_fields),
+                "status": pull_field(3, claim_fields),
+                "start_date": pull_field(4, claim_fields),
+                "end_date": pull_field(5, claim_fields),
+                "description": pull_field(6, claim_fields),
+                "source_text": pull_field(7, claim_fields),
+                "doc_id": pull_field(8, claim_fields),
+            })
+        return result
diff --git a/func-app/graphrag/index/graph/extractors/claims/prompts.py b/func-app/graphrag/index/graph/extractors/claims/prompts.py
new file mode 100644
index 0000000000..05b3153c20
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/claims/prompts.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A file containing prompts definition."""
+
+CLAIM_EXTRACTION_PROMPT = """
+-Target activity-
+You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.
+
+-Goal-
+Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.
+
+-Steps-
+1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
+2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
+For each claim, extract the following information:
+- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
+- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
+- Claim Type: overall category of the claim, capitalized.
Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
+- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
+- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
+- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
+- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.
+
+Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
+
+3. Return output in English as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
+
+4. When finished, output {completion_delimiter}
+
+-Examples-
+Example 1:
+Entity specification: organization
+Claim description: red flags associated with an entity
+Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
+Output:
+
+(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
+{completion_delimiter}
+
+Example 2:
+Entity specification: Company A, Person C
+Claim description: red flags associated with an entity
+Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
+Output:
+
+(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
+{record_delimiter} +(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015) +{completion_delimiter} + +-Real Data- +Use the following input for your answer. +Entity specification: {entity_specs} +Claim description: {claim_description} +Text: {input_text} +Output:""" + + +CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format:\n" +LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n" diff --git a/func-app/graphrag/index/graph/extractors/community_reports/__init__.py b/func-app/graphrag/index/graph/extractors/community_reports/__init__.py new file mode 100644 index 0000000000..599f56d60f --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/community_reports/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine community reports package root.""" + +import graphrag.index.graph.extractors.community_reports.schemas as schemas + +from .build_mixed_context import build_mixed_context +from .community_reports_extractor import CommunityReportsExtractor +from .prep_community_report_context import prep_community_report_context +from .prompts import COMMUNITY_REPORT_PROMPT +from .sort_context import sort_context +from .utils import ( + filter_claims_to_nodes, + filter_edges_to_nodes, + filter_nodes_to_level, + get_levels, + set_context_exceeds_flag, + set_context_size, +) + +__all__ = [ + "COMMUNITY_REPORT_PROMPT", + "CommunityReportsExtractor", + "build_mixed_context", + "filter_claims_to_nodes", + "filter_edges_to_nodes", + "filter_nodes_to_level", + "get_levels", + "prep_community_report_context", + "schemas", + "set_context_exceeds_flag", + "set_context_size", + "sort_context", +] diff --git a/func-app/graphrag/index/graph/extractors/community_reports/build_mixed_context.py b/func-app/graphrag/index/graph/extractors/community_reports/build_mixed_context.py new file mode 100644 index 0000000000..ad9e2a8447 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/community_reports/build_mixed_context.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License +"""A module containing the build_mixed_context method definition.""" + +import pandas as pd + +import graphrag.index.graph.extractors.community_reports.schemas as schemas +from graphrag.query.llm.text_utils import num_tokens + +from .sort_context import sort_context + + +def build_mixed_context(context: list[dict], max_tokens: int) -> str: + """ + Build parent context by concatenating all sub-communities' contexts. + + If the context exceeds the limit, we use sub-community reports instead. 
+ """ + sorted_context = sorted( + context, key=lambda x: x[schemas.CONTEXT_SIZE], reverse=True + ) + + # replace local context with sub-community reports, starting from the biggest sub-community + substitute_reports = [] + final_local_contexts = [] + exceeded_limit = True + context_string = "" + + for idx, sub_community_context in enumerate(sorted_context): + if exceeded_limit: + if sub_community_context[schemas.FULL_CONTENT]: + substitute_reports.append({ + schemas.COMMUNITY_ID: sub_community_context[schemas.SUB_COMMUNITY], + schemas.FULL_CONTENT: sub_community_context[schemas.FULL_CONTENT], + }) + else: + # this sub-community has no report, so we will use its local context + final_local_contexts.extend(sub_community_context[schemas.ALL_CONTEXT]) + continue + + # add local context for the remaining sub-communities + remaining_local_context = [] + for rid in range(idx + 1, len(sorted_context)): + remaining_local_context.extend(sorted_context[rid][schemas.ALL_CONTEXT]) + new_context_string = sort_context( + local_context=remaining_local_context + final_local_contexts, + sub_community_reports=substitute_reports, + ) + if num_tokens(new_context_string) <= max_tokens: + exceeded_limit = False + context_string = new_context_string + break + + if exceeded_limit: + # if all sub-community reports exceed the limit, we add reports until context is full + substitute_reports = [] + for sub_community_context in sorted_context: + substitute_reports.append({ + schemas.COMMUNITY_ID: sub_community_context[schemas.SUB_COMMUNITY], + schemas.FULL_CONTENT: sub_community_context[schemas.FULL_CONTENT], + }) + new_context_string = pd.DataFrame(substitute_reports).to_csv( + index=False, sep="," + ) + if num_tokens(new_context_string) > max_tokens: + break + + context_string = new_context_string + return context_string diff --git a/func-app/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py b/func-app/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py new file mode 100644 index 0000000000..309336fee7 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py @@ -0,0 +1,107 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing 'CommunityReportsResult' and 'CommunityReportsExtractor' models.""" + +import logging +import traceback +from dataclasses import dataclass +from typing import Any + +from graphrag.index.typing import ErrorHandlerFn +from graphrag.index.utils import dict_has_keys_with_types +from graphrag.llm import CompletionLLM + +from .prompts import COMMUNITY_REPORT_PROMPT + +log = logging.getLogger(__name__) + + +@dataclass +class CommunityReportsResult: + """Community reports result class definition.""" + + output: str + structured_output: dict + + +class CommunityReportsExtractor: + """Community reports extractor class definition.""" + + _llm: CompletionLLM + _input_text_key: str + _extraction_prompt: str + _output_formatter_prompt: str + _on_error: ErrorHandlerFn + _max_report_length: int + + def __init__( + self, + llm_invoker: CompletionLLM, + input_text_key: str | None = None, + extraction_prompt: str | None = None, + on_error: ErrorHandlerFn | None = None, + max_report_length: int | None = None, + ): + """Init method definition.""" + self._llm = llm_invoker + self._input_text_key = input_text_key or "input_text" + self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT + self._on_error = on_error or (lambda _e, _s, _d: None) + self._max_report_length = max_report_length or 1500 + + async def __call__(self, inputs: dict[str, Any]): + """Call method definition.""" + output = None + try: + response = ( + await self._llm( + self._extraction_prompt, + json=True, + name="create_community_report", + variables={self._input_text_key: inputs[self._input_text_key]}, + is_response_valid=lambda x: dict_has_keys_with_types( + x, + [ + ("title", str), + ("summary", str), + ("findings", list), + ("rating", float), + ("rating_explanation", str), + ], + ), + model_parameters={"max_tokens": self._max_report_length}, + ) + or {} + ) + output = response.json or {} + except Exception as e: + log.exception("error generating community report") + self._on_error(e, traceback.format_exc(), None) + output = {} + + text_output = self._get_text_output(output) + return CommunityReportsResult( + structured_output=output, + output=text_output, + ) + + def _get_text_output(self, parsed_output: dict) -> str: + title = parsed_output.get("title", "Report") + summary = parsed_output.get("summary", "") + findings = parsed_output.get("findings", []) + + def finding_summary(finding: dict): + if isinstance(finding, str): + return finding + return finding.get("summary") + + def finding_explanation(finding: dict): + if isinstance(finding, str): + return "" + return finding.get("explanation") + + report_sections = "\n\n".join( + f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings + ) + return f"# {title}\n\n{summary}\n\n{report_sections}" diff --git a/func-app/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py b/func-app/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py new file mode 100644 index 0000000000..2ec4222024 --- /dev/null +++ b/func-app/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py @@ -0,0 +1,181 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing create_community_reports and load_strategy methods definition.""" + +import logging +from typing import cast + +import pandas as pd + +import graphrag.index.graph.extractors.community_reports.schemas as schemas +from graphrag.index.utils.dataframes import ( + antijoin, + drop_columns, + join, + select, + transform_series, + union, + where_column_equals, +) + +from .build_mixed_context import build_mixed_context +from .sort_context import sort_context +from .utils import set_context_size + +log = logging.getLogger(__name__) + + +def prep_community_report_context( + report_df: pd.DataFrame | None, + community_hierarchy_df: pd.DataFrame, + local_context_df: pd.DataFrame, + level: int | str, + max_tokens: int, +) -> pd.DataFrame: + """ + Prep context for each community in a given level. + + For each community: + - Check if local context fits within the limit, if yes use local context + - If local context exceeds the limit, iteratively replace local context with sub-community reports, starting from the biggest sub-community + """ + if report_df is None: + report_df = pd.DataFrame() + + level = int(level) + level_context_df = _at_level(level, local_context_df) + valid_context_df = _within_context(level_context_df) + invalid_context_df = _exceeding_context(level_context_df) + + # there is no report to substitute with, so we just trim the local context of the invalid context records + # this case should only happen at the bottom level of the community hierarchy where there are no sub-communities + if invalid_context_df.empty: + return valid_context_df + + if report_df.empty: + invalid_context_df[schemas.CONTEXT_STRING] = _sort_and_trim_context( + invalid_context_df, max_tokens + ) + set_context_size(invalid_context_df) + invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = 0 + return union(valid_context_df, invalid_context_df) + + level_context_df = _antijoin_reports(level_context_df, report_df) + + # for each invalid context, we will try to substitute with sub-community reports + # first get local context and report (if available) for each sub-community + sub_context_df = _get_subcontext_df(level + 1, report_df, local_context_df) + community_df = _get_community_df( + level, invalid_context_df, sub_context_df, community_hierarchy_df, max_tokens + ) + + # handle any remaining invalid records that can't be subsituted with sub-community reports + # this should be rare, but if it happens, we will just trim the local context to fit the limit + remaining_df = _antijoin_reports(invalid_context_df, community_df) + remaining_df[schemas.CONTEXT_STRING] = _sort_and_trim_context( + remaining_df, max_tokens + ) + + result = union(valid_context_df, community_df, remaining_df) + set_context_size(result) + result[schemas.CONTEXT_EXCEED_FLAG] = 0 + return result + + +def _drop_community_level(df: pd.DataFrame) -> pd.DataFrame: + """Drop the community level column from the dataframe.""" + return drop_columns(df, schemas.COMMUNITY_LEVEL) + + +def _at_level(level: int, df: pd.DataFrame) -> pd.DataFrame: + """Return records at the given level.""" + return where_column_equals(df, schemas.COMMUNITY_LEVEL, level) + + +def _exceeding_context(df: pd.DataFrame) -> pd.DataFrame: + """Return records where the context exceeds the limit.""" + return where_column_equals(df, schemas.CONTEXT_EXCEED_FLAG, 1) + + +def _within_context(df: pd.DataFrame) -> pd.DataFrame: + """Return records where the context is within the limit.""" + return where_column_equals(df, 
schemas.CONTEXT_EXCEED_FLAG, 0)
+
+
+def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
+    """Return records in df that are not in reports."""
+    return antijoin(df, reports, schemas.NODE_COMMUNITY)
+
+
+def _sort_and_trim_context(df: pd.DataFrame, max_tokens: int) -> pd.Series:
+    """Sort and trim context to fit the limit."""
+    series = cast(pd.Series, df[schemas.ALL_CONTEXT])
+    return transform_series(series, lambda x: sort_context(x, max_tokens=max_tokens))
+
+
+def _build_mixed_context(df: pd.DataFrame, max_tokens: int) -> pd.Series:
+    """Build a mixed context of sub-community reports and local context that fits the limit."""
+    series = cast(pd.Series, df[schemas.ALL_CONTEXT])
+    return transform_series(
+        series, lambda x: build_mixed_context(x, max_tokens=max_tokens)
+    )
+
+
+def _get_subcontext_df(
+    level: int, report_df: pd.DataFrame, local_context_df: pd.DataFrame
+) -> pd.DataFrame:
+    """Get sub-community context for each community."""
+    sub_report_df = _drop_community_level(_at_level(level, report_df))
+    sub_context_df = _at_level(level, local_context_df)
+    sub_context_df = join(sub_context_df, sub_report_df, schemas.NODE_COMMUNITY)
+    sub_context_df.rename(
+        columns={schemas.NODE_COMMUNITY: schemas.SUB_COMMUNITY}, inplace=True
+    )
+    return sub_context_df
+
+
+def _get_community_df(
+    level: int,
+    invalid_context_df: pd.DataFrame,
+    sub_context_df: pd.DataFrame,
+    community_hierarchy_df: pd.DataFrame,
+    max_tokens: int,
+) -> pd.DataFrame:
+    """Get community context for each community."""
+    # collect all sub communities' contexts for each community
+    community_df = _drop_community_level(_at_level(level, community_hierarchy_df))
+    invalid_community_ids = select(invalid_context_df, schemas.NODE_COMMUNITY)
+    subcontext_selection = select(
+        sub_context_df,
+        schemas.SUB_COMMUNITY,
+        schemas.FULL_CONTENT,
+        schemas.ALL_CONTEXT,
+        schemas.CONTEXT_SIZE,
+    )
+
+    invalid_communities = join(
+        community_df, invalid_community_ids, schemas.NODE_COMMUNITY, "inner"
+    )
+    community_df = join(
+        invalid_communities, subcontext_selection, schemas.SUB_COMMUNITY
+    )
+    community_df[schemas.ALL_CONTEXT] = community_df.apply(
+        lambda x: {
+            schemas.SUB_COMMUNITY: x[schemas.SUB_COMMUNITY],
+            schemas.ALL_CONTEXT: x[schemas.ALL_CONTEXT],
+            schemas.FULL_CONTENT: x[schemas.FULL_CONTENT],
+            schemas.CONTEXT_SIZE: x[schemas.CONTEXT_SIZE],
+        },
+        axis=1,
+    )
+    community_df = (
+        community_df.groupby(schemas.NODE_COMMUNITY)
+        .agg({schemas.ALL_CONTEXT: list})
+        .reset_index()
+    )
+    community_df[schemas.CONTEXT_STRING] = _build_mixed_context(
+        community_df, max_tokens
+    )
+    community_df[schemas.COMMUNITY_LEVEL] = level
+    return community_df
diff --git a/func-app/graphrag/index/graph/extractors/community_reports/prompts.py b/func-app/graphrag/index/graph/extractors/community_reports/prompts.py
new file mode 100644
index 0000000000..35ca38bc8b
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/community_reports/prompts.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+"""A file containing prompts definition."""
+
+COMMUNITY_REPORT_PROMPT = """
+You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network. 
+
+# Goal
+Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims.
+
+# Report Structure
+
+The report should include the following sections:
+
+- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
+- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
+- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
+- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
+- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
+
+Return output as a well-formed JSON-formatted string with the following format:
+    {{
+        "title": <report_title>,
+        "summary": <executive_summary>,
+        "rating": <impact_severity_rating>,
+        "rating_explanation": <rating_explanation>,
+        "findings": [
+            {{
+                "summary":<insight_1_summary>,
+                "explanation": <insight_1_explanation>
+            }},
+            {{
+                "summary":<insight_2_summary>,
+                "explanation": <insight_2_explanation>
+            }}
+        ]
+    }}
+
+# Grounding Rules
+
+Points supported by data should list their data references as follows:
+
+"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
+
+Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
+
+where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+Do not include information where the supporting evidence for it is not provided.
+
+
+# Example Input
+-----------
+Text:
+
+Entities
+
+id,entity,description
+5,VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March
+6,HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza
+
+Relationships
+
+id,source,target,description
+37,VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March
+38,VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza
+39,VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza
+40,VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza
+41,VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march
+43,HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March
+
+Output:
+{{
+    "title": "Verdant Oasis Plaza and Unity March",
+    "summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. 
The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.", + "rating": 5.0, + "rating_explanation": "The impact severity rating is moderate due to the potential for unrest or conflict during the Unity March.", + "findings": [ + {{ + "summary": "Verdant Oasis Plaza as the central location", + "explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes. [Data: Entities (5), Relationships (37, 38, 39, 40, 41,+more)]" + }}, + {{ + "summary": "Harmony Assembly's role in the community", + "explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community. [Data: Entities(6), Relationships (38, 43)]" + }}, + {{ + "summary": "Unity March as a significant event", + "explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community. [Data: Relationships (39)]" + }}, + {{ + "summary": "Role of Tribune Spotlight", + "explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved. [Data: Relationships (40)]" + }} + ] +}} + + +# Real Data + +Use the following text for your answer. Do not make anything up in your answer. + +Text: +{input_text} + +The report should include the following sections: + +- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. +- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. +- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community. +- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating. +- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. 
+
+Return output as a well-formed JSON-formatted string with the following format:
+    {{
+        "title": <report_title>,
+        "summary": <executive_summary>,
+        "rating": <impact_severity_rating>,
+        "rating_explanation": <rating_explanation>,
+        "findings": [
+            {{
+                "summary":<insight_1_summary>,
+                "explanation": <insight_1_explanation>
+            }},
+            {{
+                "summary":<insight_2_summary>,
+                "explanation": <insight_2_explanation>
+            }}
+        ]
+    }}
+
+# Grounding Rules
+
+Points supported by data should list their data references as follows:
+
+"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
+
+Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
+
+where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+Do not include information where the supporting evidence for it is not provided.
+
+Output:"""
diff --git a/func-app/graphrag/index/graph/extractors/community_reports/schemas.py b/func-app/graphrag/index/graph/extractors/community_reports/schemas.py
new file mode 100644
index 0000000000..8e89e0273c
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/community_reports/schemas.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+"""Common field name definitions for community reports."""
+
+# POST-PREP NODE TABLE SCHEMA
+NODE_ID = "human_readable_id"
+NODE_NAME = "title"
+NODE_DESCRIPTION = "description"
+NODE_DEGREE = "degree"
+NODE_DETAILS = "node_details"
+NODE_COMMUNITY = "community"
+NODE_LEVEL = "level"
+
+# POST-PREP EDGE TABLE SCHEMA
+EDGE_ID = "human_readable_id"
+EDGE_SOURCE = "source"
+EDGE_TARGET = "target"
+EDGE_DESCRIPTION = "description"
+EDGE_DEGREE = "rank"
+EDGE_DETAILS = "edge_details"
+EDGE_WEIGHT = "weight"
+
+# POST-PREP CLAIM TABLE SCHEMA
+CLAIM_ID = "human_readable_id"
+CLAIM_SUBJECT = "subject_id"
+CLAIM_TYPE = "type"
+CLAIM_STATUS = "status"
+CLAIM_DESCRIPTION = "description"
+CLAIM_DETAILS = "claim_details"
+
+# COMMUNITY HIERARCHY TABLE SCHEMA
+SUB_COMMUNITY = "sub_community"
+SUB_COMMUNITY_SIZE = "sub_community_size"
+COMMUNITY_LEVEL = "level"
+
+# COMMUNITY CONTEXT TABLE SCHEMA
+ALL_CONTEXT = "all_context"
+CONTEXT_STRING = "context_string"
+CONTEXT_SIZE = "context_size"
+CONTEXT_EXCEED_FLAG = "context_exceed_limit"
+
+# COMMUNITY REPORT TABLE SCHEMA
+REPORT_ID = "id"
+COMMUNITY_ID = "id"
+COMMUNITY_LEVEL = "level"
+TITLE = "title"
+SUMMARY = "summary"
+FINDINGS = "findings"
+RATING = "rank"
+EXPLANATION = "rating_explanation"
+FULL_CONTENT = "full_content"
+FULL_CONTENT_JSON = "full_content_json"
diff --git a/func-app/graphrag/index/graph/extractors/community_reports/sort_context.py b/func-app/graphrag/index/graph/extractors/community_reports/sort_context.py
new file mode 100644
index 0000000000..811cb7e95c
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/community_reports/sort_context.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2024 Microsoft Corporation.
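+# Example (sketch): each element of the `local_context` list consumed by
+# sort_context below is a per-node record carrying node, edge, and claim
+# details keyed by the schemas module's column names; the literal values here
+# are invented for illustration:
+#
+#   local_context = [
+#       {
+#           "title": "VERDANT OASIS PLAZA",
+#           "node_details": {"human_readable_id": 5, "title": "VERDANT OASIS PLAZA", "degree": 4},
+#           "edge_details": [
+#               {"human_readable_id": 37, "source": "VERDANT OASIS PLAZA",
+#                "target": "UNITY MARCH", "rank": 6},
+#           ],
+#           "claim_details": [],
+#       },
+#   ]
+#   context_csv = sort_context(local_context, max_tokens=8_000)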
+# Licensed under the MIT License +"""Sort context by degree in descending order.""" + +import pandas as pd + +import graphrag.index.graph.extractors.community_reports.schemas as schemas +from graphrag.query.llm.text_utils import num_tokens + + +def sort_context( + local_context: list[dict], + sub_community_reports: list[dict] | None = None, + max_tokens: int | None = None, + node_id_column: str = schemas.NODE_ID, + node_name_column: str = schemas.NODE_NAME, + node_details_column: str = schemas.NODE_DETAILS, + edge_id_column: str = schemas.EDGE_ID, + edge_details_column: str = schemas.EDGE_DETAILS, + edge_degree_column: str = schemas.EDGE_DEGREE, + edge_source_column: str = schemas.EDGE_SOURCE, + edge_target_column: str = schemas.EDGE_TARGET, + claim_id_column: str = schemas.CLAIM_ID, + claim_details_column: str = schemas.CLAIM_DETAILS, + community_id_column: str = schemas.COMMUNITY_ID, +) -> str: + """Sort context by degree in descending order. + + If max tokens is provided, we will return the context string that fits within the token limit. + """ + + def _get_context_string( + entities: list[dict], + edges: list[dict], + claims: list[dict], + sub_community_reports: list[dict] | None = None, + ) -> str: + """Concatenate structured data into a context string.""" + contexts = [] + if sub_community_reports: + sub_community_reports = [ + report + for report in sub_community_reports + if community_id_column in report + and report[community_id_column] + and str(report[community_id_column]).strip() != "" + ] + report_df = pd.DataFrame(sub_community_reports).drop_duplicates() + if not report_df.empty: + if report_df[community_id_column].dtype == float: + report_df[community_id_column] = report_df[ + community_id_column + ].astype(int) + report_string = ( + f"----Reports-----\n{report_df.to_csv(index=False, sep=',')}" + ) + contexts.append(report_string) + + entities = [ + entity + for entity in entities + if node_id_column in entity + and entity[node_id_column] + and str(entity[node_id_column]).strip() != "" + ] + entity_df = pd.DataFrame(entities).drop_duplicates() + if not entity_df.empty: + if entity_df[node_id_column].dtype == float: + entity_df[node_id_column] = entity_df[node_id_column].astype(int) + entity_string = ( + f"-----Entities-----\n{entity_df.to_csv(index=False, sep=',')}" + ) + contexts.append(entity_string) + + if claims and len(claims) > 0: + claims = [ + claim + for claim in claims + if claim_id_column in claim + and claim[claim_id_column] + and str(claim[claim_id_column]).strip() != "" + ] + claim_df = pd.DataFrame(claims).drop_duplicates() + if not claim_df.empty: + if claim_df[claim_id_column].dtype == float: + claim_df[claim_id_column] = claim_df[claim_id_column].astype(int) + claim_string = ( + f"-----Claims-----\n{claim_df.to_csv(index=False, sep=',')}" + ) + contexts.append(claim_string) + + edges = [ + edge + for edge in edges + if edge_id_column in edge + and edge[edge_id_column] + and str(edge[edge_id_column]).strip() != "" + ] + edge_df = pd.DataFrame(edges).drop_duplicates() + if not edge_df.empty: + if edge_df[edge_id_column].dtype == float: + edge_df[edge_id_column] = edge_df[edge_id_column].astype(int) + edge_string = ( + f"-----Relationships-----\n{edge_df.to_csv(index=False, sep=',')}" + ) + contexts.append(edge_string) + + return "\n\n".join(contexts) + + # sort node details by degree in descending order + edges = [] + node_details = {} + claim_details = {} + + for record in local_context: + node_name = record[node_name_column] + record_edges = 
record.get(edge_details_column, [])
+        record_edges = [e for e in record_edges if not pd.isna(e)]
+        record_node_details = record[node_details_column]
+        record_claims = record.get(claim_details_column, [])
+        record_claims = [c for c in record_claims if not pd.isna(c)]
+
+        edges.extend(record_edges)
+        node_details[node_name] = record_node_details
+        claim_details[node_name] = record_claims
+
+    edges = [edge for edge in edges if isinstance(edge, dict)]
+    edges = sorted(edges, key=lambda x: x[edge_degree_column], reverse=True)
+
+    sorted_edges = []
+    sorted_nodes = []
+    sorted_claims = []
+    context_string = ""
+    for edge in edges:
+        source_details = node_details.get(edge[edge_source_column], {})
+        target_details = node_details.get(edge[edge_target_column], {})
+        sorted_nodes.extend([source_details, target_details])
+        sorted_edges.append(edge)
+        source_claims = claim_details.get(edge[edge_source_column], [])
+        target_claims = claim_details.get(edge[edge_target_column], [])
+        sorted_claims.extend(source_claims if source_claims else [])
+        sorted_claims.extend(target_claims if target_claims else [])
+        if max_tokens:
+            new_context_string = _get_context_string(
+                sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
+            )
+            if num_tokens(new_context_string) > max_tokens:
+                break
+            context_string = new_context_string
+
+    if context_string == "":
+        return _get_context_string(
+            sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
+        )
+
+    return context_string
diff --git a/func-app/graphrag/index/graph/extractors/community_reports/utils.py b/func-app/graphrag/index/graph/extractors/community_reports/utils.py
new file mode 100644
index 0000000000..b5fc9af9b8
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/community_reports/utils.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2024 Microsoft Corporation.
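+# Example (sketch): the helpers below annotate a context dataframe in place;
+# the dataframe only needs the `context_string` column to start with:
+#
+#   df = pd.DataFrame({"context_string": ["-----Entities-----\n...", ""]})
+#   set_context_size(df)                  # adds the `context_size` token counts
+#   set_context_exceeds_flag(df, 16_000)  # adds the `context_exceed_limit` flag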
+# Licensed under the MIT License
+
+"""A module containing community report generation utilities."""
+
+from typing import cast
+
+import pandas as pd
+
+import graphrag.index.graph.extractors.community_reports.schemas as schemas
+from graphrag.query.llm.text_utils import num_tokens
+
+
+def set_context_size(df: pd.DataFrame) -> None:
+    """Measure the number of tokens in the context."""
+    df[schemas.CONTEXT_SIZE] = df[schemas.CONTEXT_STRING].apply(num_tokens)
+
+
+def set_context_exceeds_flag(df: pd.DataFrame, max_tokens: int) -> None:
+    """Set a flag to indicate if the context exceeds the limit."""
+    df[schemas.CONTEXT_EXCEED_FLAG] = df[schemas.CONTEXT_SIZE].apply(
+        lambda x: x > max_tokens
+    )
+
+
+def get_levels(df: pd.DataFrame, level_column: str = schemas.NODE_LEVEL) -> list[int]:
+    """Get the levels of the communities."""
+    result = sorted(df[level_column].fillna(-1).unique().tolist(), reverse=True)
+    return [r for r in result if r != -1]
+
+
+def filter_nodes_to_level(node_df: pd.DataFrame, level: int) -> pd.DataFrame:
+    """Filter nodes to level."""
+    return cast(pd.DataFrame, node_df[node_df[schemas.NODE_LEVEL] == level])
+
+
+def filter_edges_to_nodes(edge_df: pd.DataFrame, nodes: list[str]) -> pd.DataFrame:
+    """Filter edges to nodes."""
+    return cast(
+        pd.DataFrame,
+        edge_df[
+            edge_df[schemas.EDGE_SOURCE].isin(nodes)
+            & edge_df[schemas.EDGE_TARGET].isin(nodes)
+        ],
+    )
+
+
+def filter_claims_to_nodes(claims_df: pd.DataFrame, nodes: list[str]) -> pd.DataFrame:
+    """Filter claims to nodes."""
+    return cast(
+        pd.DataFrame,
+        claims_df[claims_df[schemas.CLAIM_SUBJECT].isin(nodes)],
+    )
diff --git a/func-app/graphrag/index/graph/extractors/graph/__init__.py b/func-app/graphrag/index/graph/extractors/graph/__init__.py
new file mode 100644
index 0000000000..94e03ab9f7
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/graph/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine unipartite graph package root."""
+
+from .graph_extractor import (
+    DEFAULT_ENTITY_TYPES,
+    GraphExtractionResult,
+    GraphExtractor,
+)
+from .prompts import GRAPH_EXTRACTION_PROMPT
+
+__all__ = [
+    "DEFAULT_ENTITY_TYPES",
+    "GRAPH_EXTRACTION_PROMPT",
+    "GraphExtractionResult",
+    "GraphExtractor",
+]
diff --git a/func-app/graphrag/index/graph/extractors/graph/graph_extractor.py b/func-app/graphrag/index/graph/extractors/graph/graph_extractor.py
new file mode 100644
index 0000000000..f1ba0011f9
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/graph/graph_extractor.py
@@ -0,0 +1,305 @@
+# Copyright (c) 2024 Microsoft Corporation.
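+# Example (sketch): with the default delimiters, a single LLM response parsed
+# by this extractor looks roughly like the following (entities and the
+# relationship are invented for illustration):
+#
+#   ("entity"<|>ACME CORP<|>ORGANIZATION<|>Acme Corp is a manufacturer)
+#   ##
+#   ("entity"<|>JANE DOE<|>PERSON<|>Jane Doe is the CEO of Acme Corp)
+#   ##
+#   ("relationship"<|>JANE DOE<|>ACME CORP<|>Jane Doe leads Acme Corp<|>9)
+#   <|COMPLETE|>
+#
+# _process_results below splits on the record delimiter ("##"), strips the
+# parentheses, splits each record on the tuple delimiter ("<|>"), and turns
+# entity records into nodes and relationship records into weighted edges.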
+# Licensed under the MIT License
+
+"""A module containing 'GraphExtractionResult' and 'GraphExtractor' models."""
+
+import logging
+import re
+import traceback
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Any
+
+import networkx as nx
+import tiktoken
+
+import graphrag.config.defaults as defs
+from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.utils import clean_str
+from graphrag.llm import CompletionLLM
+
+from .prompts import CONTINUE_PROMPT, GRAPH_EXTRACTION_PROMPT, LOOP_PROMPT
+
+DEFAULT_TUPLE_DELIMITER = "<|>"
+DEFAULT_RECORD_DELIMITER = "##"
+DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
+DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
+
+
+@dataclass
+class GraphExtractionResult:
+    """Unipartite graph extraction result class definition."""
+
+    output: nx.Graph
+    source_docs: dict[Any, Any]
+
+
+class GraphExtractor:
+    """Unipartite graph extractor class definition."""
+
+    _llm: CompletionLLM
+    _join_descriptions: bool
+    _tuple_delimiter_key: str
+    _record_delimiter_key: str
+    _entity_types_key: str
+    _input_text_key: str
+    _completion_delimiter_key: str
+    _entity_name_key: str
+    _input_descriptions_key: str
+    _extraction_prompt: str
+    _summarization_prompt: str
+    _loop_args: dict[str, Any]
+    _max_gleanings: int
+    _on_error: ErrorHandlerFn
+
+    def __init__(
+        self,
+        llm_invoker: CompletionLLM,
+        tuple_delimiter_key: str | None = None,
+        record_delimiter_key: str | None = None,
+        input_text_key: str | None = None,
+        entity_types_key: str | None = None,
+        completion_delimiter_key: str | None = None,
+        prompt: str | None = None,
+        join_descriptions=True,
+        encoding_model: str | None = None,
+        max_gleanings: int | None = None,
+        on_error: ErrorHandlerFn | None = None,
+    ):
+        """Init method definition."""
+        # TODO: streamline construction
+        self._llm = llm_invoker
+        self._join_descriptions = join_descriptions
+        self._input_text_key = input_text_key or "input_text"
+        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
+        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
+        self._completion_delimiter_key = (
+            completion_delimiter_key or "completion_delimiter"
+        )
+        self._entity_types_key = entity_types_key or "entity_types"
+        self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
+        self._max_gleanings = (
+            max_gleanings
+            if max_gleanings is not None
+            else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
+        )
+        self._on_error = on_error or (lambda _e, _s, _d: None)
+
+        # Construct the looping arguments
+        encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
+        yes = encoding.encode("YES")
+        no = encoding.encode("NO")
+        self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
+
+    async def __call__(
+        self, texts: list[str], prompt_variables: dict[str, Any] | None = None
+    ) -> GraphExtractionResult:
+        """Call method definition."""
+        if prompt_variables is None:
+            prompt_variables = {}
+        all_records: dict[int, str] = {}
+        source_doc_map: dict[int, str] = {}
+
+        # Wire defaults into the prompt variables
+        prompt_variables = {
+            **prompt_variables,
+            self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
+            or DEFAULT_TUPLE_DELIMITER,
+            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
+            or DEFAULT_RECORD_DELIMITER,
+            self._completion_delimiter_key: prompt_variables.get(
+                self._completion_delimiter_key
+            )
+            or DEFAULT_COMPLETION_DELIMITER,
+            self._entity_types_key: ",".join(
+                
prompt_variables[self._entity_types_key] or DEFAULT_ENTITY_TYPES + ), + } + + for doc_index, text in enumerate(texts): + try: + # Invoke the entity extraction + result = await self._process_document(text, prompt_variables) + source_doc_map[doc_index] = text + all_records[doc_index] = result + except Exception as e: + logging.exception("error extracting graph") + self._on_error( + e, + traceback.format_exc(), + { + "doc_index": doc_index, + "text": text, + }, + ) + + output = await self._process_results( + all_records, + prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER), + prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER), + ) + + return GraphExtractionResult( + output=output, + source_docs=source_doc_map, + ) + + async def _process_document( + self, text: str, prompt_variables: dict[str, str] + ) -> str: + response = await self._llm( + self._extraction_prompt, + variables={ + **prompt_variables, + self._input_text_key: text, + }, + ) + results = response.output or "" + + # Repeat to ensure we maximize entity count + for i in range(self._max_gleanings): + response = await self._llm( + CONTINUE_PROMPT, + name=f"extract-continuation-{i}", + history=response.history, + ) + results += response.output or "" + + # if this is the final glean, don't bother updating the continuation flag + if i >= self._max_gleanings - 1: + break + + response = await self._llm( + LOOP_PROMPT, + name=f"extract-loopcheck-{i}", + history=response.history, + model_parameters=self._loop_args, + ) + if response.output != "YES": + break + + return results + + async def _process_results( + self, + results: dict[int, str], + tuple_delimiter: str, + record_delimiter: str, + ) -> nx.Graph: + """Parse the result string to create an undirected unipartite graph. 
+
+        Args:
+            - results - dict of results from the extraction chain
+            - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
+            - record_delimiter - delimiter between records, default is '##'
+        Returns:
+            - output - unipartite graph as a networkx Graph
+        """
+        graph = nx.Graph()
+        for source_doc_id, extracted_data in results.items():
+            records = [r.strip() for r in extracted_data.split(record_delimiter)]
+
+            for record in records:
+                record = re.sub(r"^\(|\)$", "", record.strip())
+                record_attributes = record.split(tuple_delimiter)
+
+                if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
+                    # add this record as a node in the G
+                    entity_name = clean_str(record_attributes[1].upper())
+                    entity_type = clean_str(record_attributes[2].upper())
+                    entity_description = clean_str(record_attributes[3])
+
+                    if entity_name in graph.nodes():
+                        node = graph.nodes[entity_name]
+                        if self._join_descriptions:
+                            node["description"] = "\n".join(
+                                list({
+                                    *_unpack_descriptions(node),
+                                    entity_description,
+                                })
+                            )
+                        else:
+                            if len(entity_description) > len(node["description"]):
+                                node["description"] = entity_description
+                        node["source_id"] = ", ".join(
+                            list({
+                                *_unpack_source_ids(node),
+                                str(source_doc_id),
+                            })
+                        )
+                        node["entity_type"] = (
+                            entity_type if entity_type != "" else node["entity_type"]
+                        )
+                    else:
+                        graph.add_node(
+                            entity_name,
+                            type=entity_type,
+                            description=entity_description,
+                            source_id=str(source_doc_id),
+                        )
+
+                if (
+                    record_attributes[0] == '"relationship"'
+                    and len(record_attributes) >= 5
+                ):
+                    # add this record as edge
+                    source = clean_str(record_attributes[1].upper())
+                    target = clean_str(record_attributes[2].upper())
+                    edge_description = clean_str(record_attributes[3])
+                    edge_source_id = clean_str(str(source_doc_id))
+                    # the strength arrives as text, so parse it rather than
+                    # checking the (always-str) runtime type
+                    try:
+                        weight = float(record_attributes[-1])
+                    except ValueError:
+                        weight = 1.0
+                    if source not in graph.nodes():
+                        graph.add_node(
+                            source,
+                            type="",
+                            description="",
+                            source_id=edge_source_id,
+                        )
+                    if target not in graph.nodes():
+                        graph.add_node(
+                            target,
+                            type="",
+                            description="",
+                            source_id=edge_source_id,
+                        )
+                    if graph.has_edge(source, target):
+                        edge_data = graph.get_edge_data(source, target)
+                        if edge_data is not None:
+                            weight += edge_data["weight"]
+                            if self._join_descriptions:
+                                edge_description = "\n".join(
+                                    list({
+                                        *_unpack_descriptions(edge_data),
+                                        edge_description,
+                                    })
+                                )
+                            edge_source_id = ", ".join(
+                                list({
+                                    *_unpack_source_ids(edge_data),
+                                    str(source_doc_id),
+                                })
+                            )
+                    graph.add_edge(
+                        source,
+                        target,
+                        weight=weight,
+                        description=edge_description,
+                        source_id=edge_source_id,
+                    )
+
+        return graph
+
+
+def _unpack_descriptions(data: Mapping) -> list[str]:
+    value = data.get("description", None)
+    return [] if value is None else value.split("\n")
+
+
+def _unpack_source_ids(data: Mapping) -> list[str]:
+    value = data.get("source_id", None)
+    return [] if value is None else value.split(", ")
diff --git a/func-app/graphrag/index/graph/extractors/graph/prompts.py b/func-app/graphrag/index/graph/extractors/graph/prompts.py
new file mode 100644
index 0000000000..cb1bcc668a
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/graph/prompts.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2024 Microsoft Corporation.
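+# Example (sketch): GRAPH_EXTRACTION_PROMPT below is a str.format-style
+# template; GraphExtractor substitutes the variables through its LLM wrapper,
+# but the effect is roughly equivalent to (values shown are the defaults):
+#
+#   prompt = GRAPH_EXTRACTION_PROMPT.format(
+#       entity_types="organization,person,geo,event",
+#       tuple_delimiter="<|>",
+#       record_delimiter="##",
+#       completion_delimiter="<|COMPLETE|>",
+#       input_text="TechGlobal's stock skyrocketed ...",
+#   )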
+# Licensed under the MIT License
+
+"""A file containing prompts definition."""
+
+GRAPH_EXTRACTION_PROMPT = """
+-Goal-
+Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
+
+-Steps-
+1. Identify all entities. For each identified entity, extract the following information:
+- entity_name: Name of the entity, capitalized
+- entity_type: One of the following types: [{entity_types}]
+- entity_description: Comprehensive description of the entity's attributes and activities
+Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
+
+2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: name of the source entity, as identified in step 1
+- target_entity: name of the target entity, as identified in step 1
+- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
+- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
+ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
+
+3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
+
+4. When finished, output {completion_delimiter}
+
+######################
+-Examples-
+######################
+Example 1:
+Entity_types: ORGANIZATION,PERSON
+Text:
+The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%.
+######################
+Output:
+("entity"{tuple_delimiter}CENTRAL INSTITUTION{tuple_delimiter}ORGANIZATION{tuple_delimiter}The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday)
+{record_delimiter}
+("entity"{tuple_delimiter}MARTIN SMITH{tuple_delimiter}PERSON{tuple_delimiter}Martin Smith is the chair of the Central Institution)
+{record_delimiter}
+("entity"{tuple_delimiter}MARKET STRATEGY COMMITTEE{tuple_delimiter}ORGANIZATION{tuple_delimiter}The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply)
+{record_delimiter}
+("relationship"{tuple_delimiter}MARTIN SMITH{tuple_delimiter}CENTRAL INSTITUTION{tuple_delimiter}Martin Smith is the Chair of the Central Institution and will answer questions at a press conference{tuple_delimiter}9)
+{completion_delimiter}
+
+######################
+Example 2:
+Entity_types: ORGANIZATION
+Text:
+TechGlobal's (TG) stock skyrocketed in its opening day on the Global Exchange Thursday. But IPO experts warn that the semiconductor corporation's debut on the public markets isn't indicative of how other newly listed companies may perform.
+
+TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones. 
+
+######################
+Output:
+("entity"{tuple_delimiter}TECHGLOBAL{tuple_delimiter}ORGANIZATION{tuple_delimiter}TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones)
+{record_delimiter}
+("entity"{tuple_delimiter}VISION HOLDINGS{tuple_delimiter}ORGANIZATION{tuple_delimiter}Vision Holdings is a firm that previously owned TechGlobal)
+{record_delimiter}
+("relationship"{tuple_delimiter}TECHGLOBAL{tuple_delimiter}VISION HOLDINGS{tuple_delimiter}Vision Holdings formerly owned TechGlobal from 2014 until present{tuple_delimiter}5)
+{completion_delimiter}
+
+######################
+Example 3:
+Entity_types: ORGANIZATION,GEO,PERSON
+Text:
+Five Aurelians jailed for 8 years in Firuzabad and widely regarded as hostages are on their way home to Aurelia.
+
+The swap orchestrated by Quintara was finalized when $8bn of Firuzi funds were transferred to financial institutions in Krohaara, the capital of Quintara.
+
+The exchange initiated in Firuzabad's capital, Tiruzia, led to the four men and one woman, who are also Firuzi nationals, boarding a chartered flight to Krohaara.
+
+They were welcomed by senior Aurelian officials and are now on their way to Aurelia's capital, Cashion.
+
+The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality.
+######################
+Output:
+("entity"{tuple_delimiter}FIRUZABAD{tuple_delimiter}GEO{tuple_delimiter}Firuzabad held Aurelians as hostages)
+{record_delimiter}
+("entity"{tuple_delimiter}AURELIA{tuple_delimiter}GEO{tuple_delimiter}Country seeking to release hostages)
+{record_delimiter}
+("entity"{tuple_delimiter}QUINTARA{tuple_delimiter}GEO{tuple_delimiter}Country that negotiated a swap of money in exchange for hostages)
+{record_delimiter}
+("entity"{tuple_delimiter}TIRUZIA{tuple_delimiter}GEO{tuple_delimiter}Capital of Firuzabad where the Aurelians were being held)
+{record_delimiter}
+("entity"{tuple_delimiter}KROHAARA{tuple_delimiter}GEO{tuple_delimiter}Capital city in Quintara)
+{record_delimiter}
+("entity"{tuple_delimiter}CASHION{tuple_delimiter}GEO{tuple_delimiter}Capital city in Aurelia)
+{record_delimiter}
+("entity"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}PERSON{tuple_delimiter}Aurelian who spent time in Tiruzia's Alhamia Prison)
+{record_delimiter}
+("entity"{tuple_delimiter}ALHAMIA PRISON{tuple_delimiter}GEO{tuple_delimiter}Prison in Tiruzia)
+{record_delimiter}
+("entity"{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}PERSON{tuple_delimiter}Aurelian journalist who was held hostage)
+{record_delimiter}
+("entity"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}PERSON{tuple_delimiter}Bratinas national and environmentalist who was held hostage)
+{record_delimiter}
+("relationship"{tuple_delimiter}FIRUZABAD{tuple_delimiter}AURELIA{tuple_delimiter}Firuzabad negotiated a hostage exchange with Aurelia{tuple_delimiter}2)
+{record_delimiter}
+("relationship"{tuple_delimiter}QUINTARA{tuple_delimiter}AURELIA{tuple_delimiter}Quintara brokered the hostage exchange between Firuzabad and Aurelia{tuple_delimiter}2)
+{record_delimiter}
+("relationship"{tuple_delimiter}QUINTARA{tuple_delimiter}FIRUZABAD{tuple_delimiter}Quintara brokered the hostage exchange between Firuzabad and Aurelia{tuple_delimiter}2)
+{record_delimiter}
+("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}ALHAMIA PRISON{tuple_delimiter}Samuel Namara 
was a prisoner at Alhamia prison{tuple_delimiter}8)
+{record_delimiter}
+("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}Samuel Namara and Meggie Tazbah were exchanged in the same hostage release{tuple_delimiter}2)
+{record_delimiter}
+("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}Samuel Namara and Durke Bataglani were exchanged in the same hostage release{tuple_delimiter}2)
+{record_delimiter}
+("relationship"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release{tuple_delimiter}2)
+{record_delimiter}
+("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}FIRUZABAD{tuple_delimiter}Samuel Namara was a hostage in Firuzabad{tuple_delimiter}2)
+{record_delimiter}
+("relationship"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}FIRUZABAD{tuple_delimiter}Meggie Tazbah was a hostage in Firuzabad{tuple_delimiter}2)
+{record_delimiter}
+("relationship"{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}FIRUZABAD{tuple_delimiter}Durke Bataglani was a hostage in Firuzabad{tuple_delimiter}2)
+{completion_delimiter}
+
+######################
+-Real Data-
+######################
+Entity_types: {entity_types}
+Text: {input_text}
+######################
+Output:"""
+
+CONTINUE_PROMPT = "MANY entities and relationships were missed in the last extraction. Remember to ONLY emit entities that match any of the previously extracted types. Add them below using the same format:\n"
+LOOP_PROMPT = "It appears some entities and relationships may have still been missed. Answer YES | NO if there are still entities or relationships that need to be added.\n"
diff --git a/func-app/graphrag/index/graph/extractors/summarize/__init__.py b/func-app/graphrag/index/graph/extractors/summarize/__init__.py
new file mode 100644
index 0000000000..b4bfe5be87
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/summarize/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph description summarization package root."""
+
+from .description_summary_extractor import (
+    SummarizationResult,
+    SummarizeExtractor,
+)
+from .prompts import SUMMARIZE_PROMPT
+
+__all__ = ["SUMMARIZE_PROMPT", "SummarizationResult", "SummarizeExtractor"]
diff --git a/func-app/graphrag/index/graph/extractors/summarize/description_summary_extractor.py b/func-app/graphrag/index/graph/extractors/summarize/description_summary_extractor.py
new file mode 100644
index 0000000000..76d77202d3
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/summarize/description_summary_extractor.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2024 Microsoft Corporation.
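+# Example (sketch): merging several descriptions of one entity into a single
+# summary; `llm` stands in for any CompletionLLM-compatible callable and the
+# call must run inside an async context:
+#
+#   extractor = SummarizeExtractor(llm_invoker=llm, max_summary_length=300)
+#   result = await extractor(
+#       items="ACME CORP",
+#       descriptions=["Acme Corp is a manufacturer.", "Acme Corp employs 500 people."],
+#   )
+#   print(result.description)  # one merged description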
+# Licensed under the MIT License
+
+"""A module containing 'SummarizationResult' and 'SummarizeExtractor' models."""
+
+import json
+from dataclasses import dataclass
+
+from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.utils.tokens import num_tokens_from_string
+from graphrag.llm import CompletionLLM
+
+from .prompts import SUMMARIZE_PROMPT
+
+# Max token size for input prompts
+DEFAULT_MAX_INPUT_TOKENS = 4_000
+# Max token count for LLM answers
+DEFAULT_MAX_SUMMARY_LENGTH = 500
+
+
+@dataclass
+class SummarizationResult:
+    """Description summarization result class definition."""
+
+    items: str | tuple[str, str]
+    description: str
+
+
+class SummarizeExtractor:
+    """Entity description summarization extractor class definition."""
+
+    _llm: CompletionLLM
+    _entity_name_key: str
+    _input_descriptions_key: str
+    _summarization_prompt: str
+    _on_error: ErrorHandlerFn
+    _max_summary_length: int
+    _max_input_tokens: int
+
+    def __init__(
+        self,
+        llm_invoker: CompletionLLM,
+        entity_name_key: str | None = None,
+        input_descriptions_key: str | None = None,
+        summarization_prompt: str | None = None,
+        on_error: ErrorHandlerFn | None = None,
+        max_summary_length: int | None = None,
+        max_input_tokens: int | None = None,
+    ):
+        """Init method definition."""
+        # TODO: streamline construction
+        self._llm = llm_invoker
+        self._entity_name_key = entity_name_key or "entity_name"
+        self._input_descriptions_key = input_descriptions_key or "description_list"
+
+        self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
+        self._on_error = on_error or (lambda _e, _s, _d: None)
+        self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
+        self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS
+
+    async def __call__(
+        self,
+        items: str | tuple[str, str],
+        descriptions: list[str],
+    ) -> SummarizationResult:
+        """Call method definition."""
+        result = ""
+        if len(descriptions) == 1:
+            result = descriptions[0]
+        elif len(descriptions) > 1:
+            result = await self._summarize_descriptions(items, descriptions)
+
+        return SummarizationResult(
+            items=items,
+            description=result or "",
+        )
+
+    async def _summarize_descriptions(
+        self, items: str | tuple[str, str], descriptions: list[str]
+    ) -> str:
+        """Summarize descriptions into a single description."""
+        sorted_items = sorted(items) if isinstance(items, list) else items
+
+        # Safety check, should always be a list
+        if not isinstance(descriptions, list):
+            descriptions = [descriptions]
+
+        # Iterate over descriptions, adding all until the max input tokens is reached
+        usable_tokens = self._max_input_tokens - num_tokens_from_string(
+            self._summarization_prompt
+        )
+        descriptions_collected = []
+        result = ""
+
+        for i, description in enumerate(descriptions):
+            usable_tokens -= num_tokens_from_string(description)
+            descriptions_collected.append(description)
+
+            # If buffer is full, or all descriptions have been added, summarize
+            if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
+                i == len(descriptions) - 1
+            ):
+                # Calculate result (final or partial)
+                result = await self._summarize_descriptions_with_llm(
+                    sorted_items, descriptions_collected
+                )
+
+                # If we go for another loop, reset values to new
+                if i != len(descriptions) - 1:
+                    descriptions_collected = [result]
+                    usable_tokens = (
+                        self._max_input_tokens
+                        - num_tokens_from_string(self._summarization_prompt)
+                        - num_tokens_from_string(result)
+                    )
+
+        return result
+
+    async def _summarize_descriptions_with_llm(
+        self, 
items: str | tuple[str, str] | list[str], descriptions: list[str]
+    ):
+        """Summarize descriptions using the LLM."""
+        response = await self._llm(
+            self._summarization_prompt,
+            name="summarize",
+            variables={
+                self._entity_name_key: json.dumps(items),
+                self._input_descriptions_key: json.dumps(sorted(descriptions)),
+            },
+            model_parameters={"max_tokens": self._max_summary_length},
+        )
+        # Calculate result
+        return str(response.output)
diff --git a/func-app/graphrag/index/graph/extractors/summarize/prompts.py b/func-app/graphrag/index/graph/extractors/summarize/prompts.py
new file mode 100644
index 0000000000..90e4434ee8
--- /dev/null
+++ b/func-app/graphrag/index/graph/extractors/summarize/prompts.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A file containing prompts definition."""
+
+SUMMARIZE_PROMPT = """
+You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
+Given one or two entities and a list of descriptions, all related to the same entity or group of entities,
+please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
+If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
+Make sure it is written in third person, and include the entity names so we have the full context.
+
+#######
+-Data-
+Entities: {entity_name}
+Description List: {description_list}
+#######
+Output:
+"""
diff --git a/func-app/graphrag/index/graph/utils/__init__.py b/func-app/graphrag/index/graph/utils/__init__.py
new file mode 100644
index 0000000000..6d4479283a
--- /dev/null
+++ b/func-app/graphrag/index/graph/utils/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph utils package root."""
+
+from .normalize_node_names import normalize_node_names
+from .stable_lcc import stable_largest_connected_component
+
+__all__ = ["normalize_node_names", "stable_largest_connected_component"]
diff --git a/func-app/graphrag/index/graph/utils/normalize_node_names.py b/func-app/graphrag/index/graph/utils/normalize_node_names.py
new file mode 100644
index 0000000000..bcc874a927
--- /dev/null
+++ b/func-app/graphrag/index/graph/utils/normalize_node_names.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing normalize_node_names method definition."""
+
+import html
+
+import networkx as nx
+
+
+def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
+    """Normalize node names."""
+    node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}  # type: ignore
+    return nx.relabel_nodes(graph, node_mapping)
diff --git a/func-app/graphrag/index/graph/utils/stable_lcc.py b/func-app/graphrag/index/graph/utils/stable_lcc.py
new file mode 100644
index 0000000000..7d602a6ba7
--- /dev/null
+++ b/func-app/graphrag/index/graph/utils/stable_lcc.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module for producing a stable largest connected component, i.e. 
same input graph == same output lcc."""
+
+from typing import Any, cast
+
+import networkx as nx
+from graspologic.utils import largest_connected_component
+
+from .normalize_node_names import normalize_node_names
+
+
+def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
+    """Return the largest connected component of the graph, with nodes and edges sorted in a stable way."""
+    graph = graph.copy()
+    graph = cast(nx.Graph, largest_connected_component(graph))
+    graph = normalize_node_names(graph)
+    return _stabilize_graph(graph)
+
+
+def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
+    """Ensure an undirected graph with the same relationships will always be read the same way."""
+    fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
+
+    sorted_nodes = graph.nodes(data=True)
+    sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
+
+    fixed_graph.add_nodes_from(sorted_nodes)
+    edges = list(graph.edges(data=True))
+
+    # If the graph is undirected, we create the edges in a stable way, so we get the same results
+    # for example:
+    # A -> B
+    # in graph theory is the same as
+    # B -> A
+    # in an undirected graph
+    # however, this can lead to downstream issues because sometimes
+    # consumers read graph.nodes() which ends up being [A, B] and sometimes it's [B, A]
+    # but they base some of their logic on the order of the nodes, so the order ends up being important
+    # so we sort the nodes in the edge in a stable way, so that we always get the same order
+    if not graph.is_directed():
+
+        def _sort_source_target(edge):
+            source, target, edge_data = edge
+            if source > target:
+                source, target = target, source
+            return source, target, edge_data
+
+        edges = [_sort_source_target(edge) for edge in edges]
+
+        def _get_edge_key(source: Any, target: Any) -> str:
+            return f"{source} -> {target}"
+
+        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
+
+    fixed_graph.add_edges_from(edges)
+    return fixed_graph
diff --git a/func-app/graphrag/index/graph/visualization/__init__.py b/func-app/graphrag/index/graph/visualization/__init__.py
new file mode 100644
index 0000000000..f7780e4e9c
--- /dev/null
+++ b/func-app/graphrag/index/graph/visualization/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph visualization package root."""
+
+from .compute_umap_positions import compute_umap_positions, get_zero_positions
+from .typing import GraphLayout, NodePosition
+
+__all__ = [
+    "GraphLayout",
+    "NodePosition",
+    "compute_umap_positions",
+    "get_zero_positions",
+]
diff --git a/func-app/graphrag/index/graph/visualization/compute_umap_positions.py b/func-app/graphrag/index/graph/visualization/compute_umap_positions.py
new file mode 100644
index 0000000000..569b7b309d
--- /dev/null
+++ b/func-app/graphrag/index/graph/visualization/compute_umap_positions.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2024 Microsoft Corporation.
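+# Example (sketch): projecting a small embedding matrix down to 2D; the
+# vectors are random placeholders rather than real graph embeddings:
+#
+#   import numpy as np
+#   vectors = np.random.default_rng(86).random((10, 128))
+#   positions = compute_umap_positions(
+#       embedding_vectors=vectors,
+#       node_labels=[f"node-{i}" for i in range(10)],
+#       n_neighbors=5,  # must be smaller than the number of samples
+#   )
+#   print(positions[0].x, positions[0].y)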
+# Licensed under the MIT License
+
+"""A module containing compute_umap_positions and visualize_embedding method definition."""
+
+import graspologic as gc
+import matplotlib.pyplot as plt
+import networkx as nx
+import numpy as np
+import umap
+
+from .typing import NodePosition
+
+
+def get_zero_positions(
+    node_labels: list[str],
+    node_categories: list[int] | None = None,
+    node_sizes: list[int] | None = None,
+    three_d: bool | None = False,
+) -> list[NodePosition]:
+    """Create zeroed 2D/3D positions for the given nodes without running UMAP."""
+    embedding_position_data: list[NodePosition] = []
+    for index, node_name in enumerate(node_labels):
+        node_category = 1 if node_categories is None else node_categories[index]
+        node_size = 1 if node_sizes is None else node_sizes[index]
+
+        if not three_d:
+            embedding_position_data.append(
+                NodePosition(
+                    label=str(node_name),
+                    x=0,
+                    y=0,
+                    cluster=str(int(node_category)),
+                    size=int(node_size),
+                )
+            )
+        else:
+            embedding_position_data.append(
+                NodePosition(
+                    label=str(node_name),
+                    x=0,
+                    y=0,
+                    z=0,
+                    cluster=str(int(node_category)),
+                    size=int(node_size),
+                )
+            )
+    return embedding_position_data
+
+
+def compute_umap_positions(
+    embedding_vectors: np.ndarray,
+    node_labels: list[str],
+    node_categories: list[int] | None = None,
+    node_sizes: list[int] | None = None,
+    min_dist: float = 0.75,
+    n_neighbors: int = 25,
+    spread: int = 1,
+    metric: str = "euclidean",
+    n_components: int = 2,
+    random_state: int = 86,
+) -> list[NodePosition]:
+    """Project embedding vectors down to 2D/3D using UMAP."""
+    embedding_positions = umap.UMAP(
+        min_dist=min_dist,
+        n_neighbors=n_neighbors,
+        spread=spread,
+        n_components=n_components,
+        metric=metric,
+        random_state=random_state,
+    ).fit_transform(embedding_vectors)
+
+    embedding_position_data: list[NodePosition] = []
+    for index, node_name in enumerate(node_labels):
+        node_points = embedding_positions[index]  # type: ignore
+        node_category = 1 if node_categories is None else node_categories[index]
+        node_size = 1 if node_sizes is None else node_sizes[index]
+
+        if len(node_points) == 2:
+            embedding_position_data.append(
+                NodePosition(
+                    label=str(node_name),
+                    x=float(node_points[0]),
+                    y=float(node_points[1]),
+                    cluster=str(int(node_category)),
+                    size=int(node_size),
+                )
+            )
+        else:
+            embedding_position_data.append(
+                NodePosition(
+                    label=str(node_name),
+                    x=float(node_points[0]),
+                    y=float(node_points[1]),
+                    z=float(node_points[2]),
+                    cluster=str(int(node_category)),
+                    size=int(node_size),
+                )
+            )
+    return embedding_position_data
+
+
+def visualize_embedding(
+    graph,
+    umap_positions: list[dict],
+):
+    """Project embedding down to 2D using UMAP and visualize."""
+    # rendering
+    plt.clf()
+    figure = plt.gcf()
+    ax = plt.gca()
+
+    ax.set_axis_off()
+    figure.set_size_inches(10, 10)
+    figure.set_dpi(400)
+
+    node_position_dict = {
+        str(position["label"]): (position["x"], position["y"])
+        for position in umap_positions
+    }
+    node_category_dict = {
+        str(position["label"]): position["category"] for position in umap_positions
+    }
+    node_sizes = [position["size"] for position in umap_positions]
+    node_colors = gc.layouts.categorical_colors(node_category_dict)  # type: ignore
+
+    vertices = []
+    node_color_list = []
+    for node in node_position_dict:
+        vertices.append(node)
+        node_color_list.append(node_colors[node])
+
+    nx.draw_networkx_nodes(
+        graph,
+        pos=node_position_dict,
+        nodelist=vertices,
+        node_color=node_color_list,  # type: ignore
+        alpha=1.0,
+        linewidths=0.01,
+        node_size=node_sizes,  # type: ignore 
+        node_shape="o",
+        ax=ax,
+    )
+    plt.show()
diff --git a/func-app/graphrag/index/graph/visualization/typing.py b/func-app/graphrag/index/graph/visualization/typing.py
new file mode 100644
index 0000000000..ae46afa928
--- /dev/null
+++ b/func-app/graphrag/index/graph/visualization/typing.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+# Use this for now instead of a wrapper
+"""A module containing 'NodePosition' model."""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class NodePosition:
+    """Node position class definition."""
+
+    label: str
+    cluster: str
+    size: float
+
+    x: float
+    y: float
+    z: float | None = None
+
+    def to_pandas(self) -> tuple[str, float, float, str, float]:
+        """To pandas method definition."""
+        return self.label, self.x, self.y, self.cluster, self.size
+
+
+GraphLayout = list[NodePosition]
diff --git a/func-app/graphrag/index/init_content.py b/func-app/graphrag/index/init_content.py
new file mode 100644
index 0000000000..8f5982f8f7
--- /dev/null
+++ b/func-app/graphrag/index/init_content.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+"""Content for the init CLI command."""
+
+import graphrag.config.defaults as defs
+
+INIT_YAML = f"""
+encoding_model: cl100k_base
+skip_workflows: []
+llm:
+  api_key: ${{GRAPHRAG_API_KEY}}
+  type: {defs.LLM_TYPE.value} # or azure_openai_chat
+  model: {defs.LLM_MODEL}
+  model_supports_json: true # recommended if this is available for your model.
+  # max_tokens: {defs.LLM_MAX_TOKENS}
+  # request_timeout: {defs.LLM_REQUEST_TIMEOUT}
+  # api_base: https://<instance>.openai.azure.com
+  # api_version: 2024-02-15-preview
+  # organization: <organization_id>
+  # deployment_name: <azure_model_deployment_name>
+  # tokens_per_minute: 150_000 # set a leaky bucket throttle
+  # requests_per_minute: 10_000 # set a leaky bucket throttle
+  # max_retries: {defs.LLM_MAX_RETRIES}
+  # max_retry_wait: {defs.LLM_MAX_RETRY_WAIT}
+  # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
+  # concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # the number of parallel inflight requests that may be made
+  # temperature: {defs.LLM_TEMPERATURE} # temperature for sampling
+  # top_p: {defs.LLM_TOP_P} # top-p sampling
+  # n: {defs.LLM_N} # Number of completions to generate
+
+parallelization:
+  stagger: {defs.PARALLELIZATION_STAGGER}
+  # num_threads: {defs.PARALLELIZATION_NUM_THREADS} # the number of threads to use for parallel processing
+
+async_mode: {defs.ASYNC_MODE.value} # or asyncio
+
+embeddings:
+  ## parallelization: override the global parallelization settings for embeddings
+  async_mode: {defs.ASYNC_MODE.value} # or asyncio
+  llm:
+    api_key: ${{GRAPHRAG_API_KEY}}
+    type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
+    model: {defs.EMBEDDING_MODEL}
+    # api_base: https://<instance>.openai.azure.com
+    # api_version: 2024-02-15-preview
+    # organization: <organization_id>
+    # deployment_name: <azure_model_deployment_name>
+    # tokens_per_minute: 150_000 # set a leaky bucket throttle
+    # requests_per_minute: 10_000 # set a leaky bucket throttle
+    # max_retries: {defs.LLM_MAX_RETRIES}
+    # max_retry_wait: {defs.LLM_MAX_RETRY_WAIT}
+    # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
+    # concurrent_requests: {defs.LLM_CONCURRENT_REQUESTS} # the number of parallel inflight requests that may be made
+    # batch_size: {defs.EMBEDDING_BATCH_SIZE} # the number of documents to send in a single request
+    # batch_max_tokens: {defs.EMBEDDING_BATCH_MAX_TOKENS} # the maximum number of tokens 
to send in a single request
+  # target: {defs.EMBEDDING_TARGET.value} # or optional
+
+
+
+chunks:
+  size: {defs.CHUNK_SIZE}
+  overlap: {defs.CHUNK_OVERLAP}
+  group_by_columns: [{",".join(defs.CHUNK_GROUP_BY_COLUMNS)}] # by default, we don't allow chunks to cross documents
+
+input:
+  type: {defs.INPUT_TYPE.value} # or blob
+  file_type: {defs.INPUT_FILE_TYPE.value} # or csv
+  base_dir: "{defs.INPUT_BASE_DIR}"
+  file_encoding: {defs.INPUT_FILE_ENCODING}
+  file_pattern: ".*\\\\.txt$"
+
+cache:
+  type: {defs.CACHE_TYPE.value} # or blob
+  base_dir: "{defs.CACHE_BASE_DIR}"
+  # connection_string: <azure_blob_storage_connection_string>
+  # container_name: <azure_blob_storage_container_name>
+
+storage:
+  type: {defs.STORAGE_TYPE.value} # or blob
+  base_dir: "{defs.STORAGE_BASE_DIR}"
+  # connection_string: <azure_blob_storage_connection_string>
+  # container_name: <azure_blob_storage_container_name>
+
+reporting:
+  type: {defs.REPORTING_TYPE.value} # or console, blob
+  base_dir: "{defs.REPORTING_BASE_DIR}"
+  # connection_string: <azure_blob_storage_connection_string>
+  # container_name: <azure_blob_storage_container_name>
+
+entity_extraction:
+  ## llm: override the global llm settings for this task
+  ## parallelization: override the global parallelization settings for this task
+  ## async_mode: override the global async_mode settings for this task
+  prompt: "prompts/entity_extraction.txt"
+  entity_types: [{",".join(defs.ENTITY_EXTRACTION_ENTITY_TYPES)}]
+  max_gleanings: {defs.ENTITY_EXTRACTION_MAX_GLEANINGS}
+
+summarize_descriptions:
+  ## llm: override the global llm settings for this task
+  ## parallelization: override the global parallelization settings for this task
+  ## async_mode: override the global async_mode settings for this task
+  prompt: "prompts/summarize_descriptions.txt"
+  max_length: {defs.SUMMARIZE_DESCRIPTIONS_MAX_LENGTH}
+
+claim_extraction:
+  ## llm: override the global llm settings for this task
+  ## parallelization: override the global parallelization settings for this task
+  ## async_mode: override the global async_mode settings for this task
+  # enabled: true
+  prompt: "prompts/claim_extraction.txt"
+  description: "{defs.CLAIM_DESCRIPTION}"
+  max_gleanings: {defs.CLAIM_MAX_GLEANINGS}
+
+community_reports:
+  ## llm: override the global llm settings for this task
+  ## parallelization: override the global parallelization settings for this task
+  ## async_mode: override the global async_mode settings for this task
+  prompt: "prompts/community_report.txt"
+  max_length: {defs.COMMUNITY_REPORT_MAX_LENGTH}
+  max_input_length: {defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH}
+
+cluster_graph:
+  max_cluster_size: {defs.MAX_CLUSTER_SIZE}
+
+embed_graph:
+  enabled: false # if true, will generate node2vec embeddings for nodes
+  # num_walks: {defs.NODE2VEC_NUM_WALKS}
+  # walk_length: {defs.NODE2VEC_WALK_LENGTH}
+  # window_size: {defs.NODE2VEC_WINDOW_SIZE}
+  # iterations: {defs.NODE2VEC_ITERATIONS}
+  # random_seed: {defs.NODE2VEC_RANDOM_SEED}
+
+umap:
+  enabled: false # if true, will generate UMAP embeddings for nodes
+
+snapshots:
+  graphml: false
+  raw_entities: false
+  top_level_nodes: false
+
+local_search:
+  # text_unit_prop: {defs.LOCAL_SEARCH_TEXT_UNIT_PROP}
+  # community_prop: {defs.LOCAL_SEARCH_COMMUNITY_PROP}
+  # conversation_history_max_turns: {defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS}
+  # top_k_mapped_entities: {defs.LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES}
+  # top_k_relationships: {defs.LOCAL_SEARCH_TOP_K_RELATIONSHIPS}
+  # llm_temperature: {defs.LOCAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling
+  # llm_top_p: {defs.LOCAL_SEARCH_LLM_TOP_P} # top-p sampling
+  # llm_n: {defs.LOCAL_SEARCH_LLM_N} # Number of completions to generate
+  # max_tokens: {defs.LOCAL_SEARCH_MAX_TOKENS}
+
+global_search:
+  
+  # llm_temperature: {defs.GLOBAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling
+  # llm_top_p: {defs.GLOBAL_SEARCH_LLM_TOP_P} # top-p sampling
+  # llm_n: {defs.GLOBAL_SEARCH_LLM_N} # Number of completions to generate
+  # max_tokens: {defs.GLOBAL_SEARCH_MAX_TOKENS}
+  # data_max_tokens: {defs.GLOBAL_SEARCH_DATA_MAX_TOKENS}
+  # map_max_tokens: {defs.GLOBAL_SEARCH_MAP_MAX_TOKENS}
+  # reduce_max_tokens: {defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS}
+  # concurrency: {defs.GLOBAL_SEARCH_CONCURRENCY}
+
+query_context:
+  # files: [] # list of files in context to run query
+
+graphdb:
+  account_name: ''
+  account_key: ''
+  username: ''
+  enabled: false
+  cosmos_url: ''
+  gremlin_url: ''
+"""
+
+INIT_DOTENV = """
+GRAPHRAG_API_KEY=<API_KEY>
+"""
diff --git a/func-app/graphrag/index/input/__init__.py b/func-app/graphrag/index/input/__init__.py
new file mode 100644
index 0000000000..91421867de
--- /dev/null
+++ b/func-app/graphrag/index/input/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine input package root."""
+
+from .load_input import load_input
+
+__all__ = ["load_input"]
diff --git a/func-app/graphrag/index/input/csv.py b/func-app/graphrag/index/input/csv.py
new file mode 100644
index 0000000000..04f43ddda0
--- /dev/null
+++ b/func-app/graphrag/index/input/csv.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing load method definition."""
+
+import logging
+import re
+from io import BytesIO
+from typing import cast
+
+import pandas as pd
+
+from graphrag.index.config import PipelineCSVInputConfig, PipelineInputConfig
+from graphrag.common.progress import ProgressReporter
+from graphrag.common.storage import PipelineStorage
+from graphrag.index.utils import gen_md5_hash
+
+log = logging.getLogger(__name__)
+
+DEFAULT_FILE_PATTERN = re.compile(r"(?P<source>[^\\/]).csv$")
+
+input_type = "csv"
+
+
+async def load(
+    config: PipelineInputConfig,
+    progress: ProgressReporter | None,
+    storage: PipelineStorage,
+) -> pd.DataFrame:
+    """Load csv inputs from a directory."""
+    csv_config = cast(PipelineCSVInputConfig, config)
+    log.info("Loading csv files from %s", csv_config.base_dir)
+
+    async def load_file(path: str, group: dict | None) -> pd.DataFrame:
+        if group is None:
+            group = {}
+        buffer = BytesIO(await storage.get(path, as_bytes=True))
+        data = pd.read_csv(buffer, encoding=config.encoding or "latin-1")
+        additional_keys = group.keys()
+        if len(additional_keys) > 0:
+            data[[*additional_keys]] = data.apply(
+                lambda _row: pd.Series([group[key] for key in additional_keys]), axis=1
+            )
+        if "id" not in data.columns:
+            data["id"] = data.apply(lambda x: gen_md5_hash(x, x.keys()), axis=1)
+        if csv_config.source_column is not None and "source" not in data.columns:
+            if csv_config.source_column not in data.columns:
+                log.warning(
+                    "source_column %s not found in csv file %s",
+                    csv_config.source_column,
+                    path,
+                )
+            else:
+                data["source"] = data.apply(
+                    lambda x: x[csv_config.source_column], axis=1
+                )
+        if csv_config.text_column is not None and "text" not in data.columns:
+            if csv_config.text_column not in data.columns:
+                log.warning(
+                    "text_column %s not found in csv file %s",
+                    csv_config.text_column,
+                    path,
+                )
+            else:
+                data["text"] = data.apply(lambda x: x[csv_config.text_column], axis=1)
csv file %s", + csv_config.title_column, + path, + ) + else: + data["title"] = data.apply(lambda x: x[csv_config.title_column], axis=1) + + if csv_config.timestamp_column is not None: + fmt = csv_config.timestamp_format + if fmt is None: + msg = "Must specify timestamp_format if timestamp_column is specified" + raise ValueError(msg) + + if csv_config.timestamp_column not in data.columns: + log.warning( + "timestamp_column %s not found in csv file %s", + csv_config.timestamp_column, + path, + ) + else: + data["timestamp"] = pd.to_datetime( + data[csv_config.timestamp_column], format=fmt + ) + + # TODO: Theres probably a less gross way to do this + if "year" not in data.columns: + data["year"] = data.apply(lambda x: x["timestamp"].year, axis=1) + if "month" not in data.columns: + data["month"] = data.apply(lambda x: x["timestamp"].month, axis=1) + if "day" not in data.columns: + data["day"] = data.apply(lambda x: x["timestamp"].day, axis=1) + if "hour" not in data.columns: + data["hour"] = data.apply(lambda x: x["timestamp"].hour, axis=1) + if "minute" not in data.columns: + data["minute"] = data.apply(lambda x: x["timestamp"].minute, axis=1) + if "second" not in data.columns: + data["second"] = data.apply(lambda x: x["timestamp"].second, axis=1) + + return data + + file_pattern = ( + re.compile(config.file_pattern) + if config.file_pattern is not None + else DEFAULT_FILE_PATTERN + ) + files = list( + storage.find( + file_pattern, + progress=progress, + file_filter=config.file_filter, + ) + ) + + if len(files) == 0: + msg = f"No CSV files found in {config.base_dir}" + raise ValueError(msg) + + files_loaded = [] + + for file, group in files: + try: + files_loaded.append(await load_file(file, group)) + except Exception: # noqa: BLE001 (catching Exception is fine here) + log.warning("Warning! Error loading csv file %s. Skipping...", file) + + log.info("Found %d csv files, loading %d", len(files), len(files_loaded)) + result = pd.concat(files_loaded) + total_files_log = f"Total number of unfiltered csv rows: {len(result)}" + log.info(total_files_log) + return result diff --git a/func-app/graphrag/index/input/load_input.py b/func-app/graphrag/index/input/load_input.py new file mode 100644 index 0000000000..14dd90635e --- /dev/null +++ b/func-app/graphrag/index/input/load_input.py @@ -0,0 +1,84 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing load_input method definition.""" + +import logging +from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import cast + +import pandas as pd + +from graphrag.config import InputConfig, InputType +from graphrag.index.config import PipelineInputConfig +from graphrag.common.progress import NullProgressReporter, ProgressReporter +from graphrag.common.storage import ( + BlobPipelineStorage, + FilePipelineStorage, +) + +from .csv import input_type as csv +from .csv import load as load_csv +from .text import input_type as text +from .text import load as load_text + +log = logging.getLogger(__name__) +loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = { + text: load_text, + csv: load_csv, +} + + +async def load_input( + config: PipelineInputConfig | InputConfig, + progress_reporter: ProgressReporter | None = None, + root_dir: str | None = None, +) -> pd.DataFrame: + """Load the input data for a pipeline.""" + root_dir = root_dir or "" + log.info(f"loading input from root_dir {root_dir}") + progress_reporter = progress_reporter or NullProgressReporter() + + if config is None: + msg = "No input specified!" + raise ValueError(msg) + + match config.type: + case InputType.blob: + log.info("using blob storage input") + if config.container_name is None: + msg = "Container name required for blob storage" + raise ValueError(msg) + if ( + config.connection_string is None + and config.storage_account_blob_url is None + ): + msg = "Connection string or storage account blob url required for blob storage" + raise ValueError(msg) + storage = BlobPipelineStorage( + connection_string=config.connection_string, + storage_account_blob_url=config.storage_account_blob_url, + container_name=config.container_name, + ) + case InputType.file: + log.info("using file storage for input") + storage = FilePipelineStorage( + root_dir=str(Path(root_dir) / (config.base_dir or "")) + ) + case _: + log.info("using file storage for input") + storage = FilePipelineStorage( + root_dir=str(Path(root_dir) / (config.base_dir or "")) + ) + + if config.file_type in loaders: + progress = progress_reporter.child( + f"Loading Input ({config.file_type})", transient=False + ) + loader = loaders[config.file_type] + results = await loader(config, progress, storage) + return cast(pd.DataFrame, results) + + msg = f"Unknown input type {config.file_type}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/input/text.py b/func-app/graphrag/index/input/text.py new file mode 100644 index 0000000000..3a3e15cbd9 --- /dev/null +++ b/func-app/graphrag/index/input/text.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A module containing load method definition."""
+
+import logging
+import re
+from pathlib import Path
+from typing import Any
+
+import pandas as pd
+
+from graphrag.index.config import PipelineInputConfig
+from graphrag.common.progress import ProgressReporter
+from graphrag.common.storage import PipelineStorage
+from graphrag.index.utils import gen_md5_hash
+
+DEFAULT_FILE_PATTERN = re.compile(
+    r".*[\\/](?P<source>[^\\/]+)[\\/](?P<year>\d{4})-(?P<month>\d{2})-(?P<day>\d{2})_(?P<author>[^_]+)_\d+\.txt"
+)
+input_type = "text"
+log = logging.getLogger(__name__)
+
+
+async def load(
+    config: PipelineInputConfig,
+    progress: ProgressReporter | None,
+    storage: PipelineStorage,
+) -> pd.DataFrame:
+    """Load text inputs from a directory."""
+
+    async def load_file(
+        path: str, group: dict | None = None, _encoding: str = "utf-8"  # what is group here, can be used as context?
+    ) -> dict[str, Any]:
+        if group is None:
+            group = {}
+        text = await storage.get(path, encoding="utf-8")
+        new_item = {**group, "text": text}
+        new_item["id"] = gen_md5_hash(new_item, new_item.keys())
+        new_item["title"] = str(Path(path).name)
+        return new_item
+
+    base_dir = config.base_dir
+    if config.type == "file":
+        # base dir is already being added to root dir in case of type file.
+        base_dir = None
+    files = list(
+        storage.find(
+            re.compile(config.file_pattern),
+            progress=progress,
+            file_filter=config.file_filter,
+            base_dir=base_dir,
+        )
+    )
+
+    if len(files) == 0:
+        msg = f"No text files found in {config.base_dir}"
+        raise ValueError(msg)
+
+    found_files = f"found text files from {config.base_dir}, found {files}"
+    log.info(found_files)
+
+    files_loaded = []
+
+    for file, group in files:
+        try:
+            files_loaded.append(await load_file(file, group))
+        except Exception:  # noqa: BLE001 (catching Exception is fine here)
+            log.warning("Warning! Error loading file %s. Skipping...", file)
+
+    log.info("Found %d files, loading %d", len(files), len(files_loaded))
+
+    return pd.DataFrame(files_loaded)
diff --git a/func-app/graphrag/index/llm/__init__.py b/func-app/graphrag/index/llm/__init__.py
new file mode 100644
index 0000000000..008ef07ccd
--- /dev/null
+++ b/func-app/graphrag/index/llm/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine LLM package root."""
+
+from .load_llm import load_llm, load_llm_embeddings
+from .types import TextListSplitter, TextSplitter
+
+__all__ = [
+    "TextListSplitter",
+    "TextSplitter",
+    "load_llm",
+    "load_llm_embeddings",
+]
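The default filename pattern in text.py encodes source, date, and author as named regex groups, which the loader merges into each row. A quick self-contained check (the path is hypothetical, purely illustrative):

    import re

    DEFAULT_FILE_PATTERN = re.compile(
        r".*[\\/](?P<source>[^\\/]+)[\\/](?P<year>\d{4})-(?P<month>\d{2})-(?P<day>\d{2})_(?P<author>[^_]+)_\d+\.txt"
    )

    match = DEFAULT_FILE_PATTERN.match("input/news/2024-08-19_asingh_0001.txt")
    assert match is not None
    # The named groups become extra columns on each loaded row.
    print(match.groupdict())
    # {'source': 'news', 'year': '2024', 'month': '08', 'day': '19', 'author': 'asingh'}
diff --git a/func-app/graphrag/index/llm/load_llm.py b/func-app/graphrag/index/llm/load_llm.py
new file mode 100644
index 0000000000..264229c887
--- /dev/null
+++ b/func-app/graphrag/index/llm/load_llm.py
@@ -0,0 +1,313 @@
+# Copyright (c) 2024 Microsoft Corporation.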
+# Licensed under the MIT License + +"""Load llm utilities.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import TYPE_CHECKING, Any + +from graphrag.config.enums import LLMType +from graphrag.llm import ( + CompletionLLM, + EmbeddingLLM, + LLMCache, + LLMLimiter, + MockCompletionLLM, + OpenAIConfiguration, + create_openai_chat_llm, + create_openai_client, + create_openai_completion_llm, + create_openai_embedding_llm, + create_tpm_rpm_limiters, +) + +if TYPE_CHECKING: + from datashaper import VerbCallbacks + + from graphrag.index.cache import PipelineCache + from graphrag.index.typing import ErrorHandlerFn + +log = logging.getLogger(__name__) + +_semaphores: dict[str, asyncio.Semaphore] = {} +_rate_limiters: dict[str, LLMLimiter] = {} + + +def load_llm( + name: str, + llm_type: LLMType, + callbacks: VerbCallbacks, + cache: PipelineCache | None, + llm_config: dict[str, Any] | None = None, + chat_only=False, +) -> CompletionLLM: + """Load the LLM for the entity extraction chain.""" + on_error = _create_error_handler(callbacks) + + if llm_type in loaders: + if chat_only and not loaders[llm_type]["chat"]: + msg = f"LLM type {llm_type} does not support chat" + raise ValueError(msg) + if cache is not None: + cache = cache.child(name) + + loader = loaders[llm_type] + return loader["load"](on_error, cache, llm_config or {}) + + msg = f"Unknown LLM type {llm_type}" + raise ValueError(msg) + + +def load_llm_embeddings( + name: str, + llm_type: LLMType, + callbacks: VerbCallbacks, + cache: PipelineCache | None, + llm_config: dict[str, Any] | None = None, + chat_only=False, +) -> EmbeddingLLM: + """Load the LLM for the entity extraction chain.""" + on_error = _create_error_handler(callbacks) + if llm_type in loaders: + if chat_only and not loaders[llm_type]["chat"]: + msg = f"LLM type {llm_type} does not support chat" + raise ValueError(msg) + if cache is not None: + cache = cache.child(name) + + return loaders[llm_type]["load"](on_error, cache, llm_config or {}) + + msg = f"Unknown LLM type {llm_type}" + raise ValueError(msg) + + +def _create_error_handler(callbacks: VerbCallbacks) -> ErrorHandlerFn: + def on_error( + error: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ) -> None: + callbacks.error("Error Invoking LLM", error, stack, details) + + return on_error + + +def _load_openai_completion_llm( + on_error: ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], + azure=False, +): + return _create_openai_completion_llm( + OpenAIConfiguration({ + **_get_base_config(config), + "model": config.get("model", "gpt-4-turbo-preview"), + "deployment_name": config.get("deployment_name"), + "temperature": config.get("temperature", 0.0), + "frequency_penalty": config.get("frequency_penalty", 0), + "presence_penalty": config.get("presence_penalty", 0), + "top_p": config.get("top_p", 1), + "max_tokens": config.get("max_tokens", 4000), + "n": config.get("n"), + }), + on_error, + cache, + azure, + ) + + +def _load_openai_chat_llm( + on_error: ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], + azure=False, +): + return _create_openai_chat_llm( + OpenAIConfiguration({ + # Set default values + **_get_base_config(config), + "model": config.get("model", "gpt-4-turbo-preview"), + "deployment_name": config.get("deployment_name"), + "temperature": config.get("temperature", 0.0), + "frequency_penalty": config.get("frequency_penalty", 0), + "presence_penalty": config.get("presence_penalty", 0), + "top_p": 
config.get("top_p", 1), + "max_tokens": config.get("max_tokens"), + "n": config.get("n"), + }), + on_error, + cache, + azure, + ) + + +def _load_openai_embeddings_llm( + on_error: ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], + azure=False, +): + # TODO: Inject Cache + return _create_openai_embeddings_llm( + OpenAIConfiguration({ + **_get_base_config(config), + "model": config.get( + "embeddings_model", config.get("model", "text-embedding-3-small") + ), + "deployment_name": config.get("deployment_name"), + }), + on_error, + cache, + azure, + ) + + +def _load_azure_openai_completion_llm( + on_error: ErrorHandlerFn, cache: LLMCache, config: dict[str, Any] +): + return _load_openai_completion_llm(on_error, cache, config, True) + + +def _load_azure_openai_chat_llm( + on_error: ErrorHandlerFn, cache: LLMCache, config: dict[str, Any] +): + return _load_openai_chat_llm(on_error, cache, config, True) + + +def _load_azure_openai_embeddings_llm( + on_error: ErrorHandlerFn, cache: LLMCache, config: dict[str, Any] +): + return _load_openai_embeddings_llm(on_error, cache, config, True) + + +def _get_base_config(config: dict[str, Any]) -> dict[str, Any]: + api_key = config.get("api_key") + + return { + # Pass in all parameterized values + **config, + # Set default values + "api_key": api_key, + "api_base": config.get("api_base"), + "api_version": config.get("api_version"), + "organization": config.get("organization"), + "proxy": config.get("proxy"), + "max_retries": config.get("max_retries", 10), + "request_timeout": config.get("request_timeout", 60.0), + "model_supports_json": config.get("model_supports_json"), + "concurrent_requests": config.get("concurrent_requests", 4), + "encoding_model": config.get("encoding_model", "cl100k_base"), + "cognitive_services_endpoint": config.get("cognitive_services_endpoint"), + } + + +def _load_static_response( + _on_error: ErrorHandlerFn, _cache: PipelineCache, config: dict[str, Any] +) -> CompletionLLM: + return MockCompletionLLM(config.get("responses", [])) + + +loaders = { + LLMType.OpenAI: { + "load": _load_openai_completion_llm, + "chat": False, + }, + LLMType.AzureOpenAI: { + "load": _load_azure_openai_completion_llm, + "chat": False, + }, + LLMType.OpenAIChat: { + "load": _load_openai_chat_llm, + "chat": True, + }, + LLMType.AzureOpenAIChat: { + "load": _load_azure_openai_chat_llm, + "chat": True, + }, + LLMType.OpenAIEmbedding: { + "load": _load_openai_embeddings_llm, + "chat": False, + }, + LLMType.AzureOpenAIEmbedding: { + "load": _load_azure_openai_embeddings_llm, + "chat": False, + }, + LLMType.StaticResponse: { + "load": _load_static_response, + "chat": False, + }, +} + + +def _create_openai_chat_llm( + configuration: OpenAIConfiguration, + on_error: ErrorHandlerFn, + cache: LLMCache, + azure=False, +) -> CompletionLLM: + """Create an openAI chat llm.""" + client = create_openai_client(configuration=configuration, azure=azure) + limiter = _create_limiter(configuration) + semaphore = _create_semaphore(configuration) + return create_openai_chat_llm( + client, configuration, cache, limiter, semaphore, on_error=on_error + ) + + +def _create_openai_completion_llm( + configuration: OpenAIConfiguration, + on_error: ErrorHandlerFn, + cache: LLMCache, + azure=False, +) -> CompletionLLM: + """Create an openAI completion llm.""" + client = create_openai_client(configuration=configuration, azure=azure) + limiter = _create_limiter(configuration) + semaphore = _create_semaphore(configuration) + return create_openai_completion_llm( + client, 
configuration, cache, limiter, semaphore, on_error=on_error
+    )
+
+
+def _create_openai_embeddings_llm(
+    configuration: OpenAIConfiguration,
+    on_error: ErrorHandlerFn,
+    cache: LLMCache,
+    azure=False,
+) -> EmbeddingLLM:
+    """Create an openAI embeddings llm."""
+    client = create_openai_client(configuration=configuration, azure=azure)
+    limiter = _create_limiter(configuration)
+    semaphore = _create_semaphore(configuration)
+    return create_openai_embedding_llm(
+        client, configuration, cache, limiter, semaphore, on_error=on_error
+    )
+
+
+def _create_limiter(configuration: OpenAIConfiguration) -> LLMLimiter:
+    limit_name = configuration.model or configuration.deployment_name or "default"
+    if limit_name not in _rate_limiters:
+        tpm = configuration.tokens_per_minute
+        rpm = configuration.requests_per_minute
+        log.info("create TPM/RPM limiter for %s: TPM=%s, RPM=%s", limit_name, tpm, rpm)
+        _rate_limiters[limit_name] = create_tpm_rpm_limiters(configuration)
+    return _rate_limiters[limit_name]
+
+
+def _create_semaphore(configuration: OpenAIConfiguration) -> asyncio.Semaphore | None:
+    limit_name = configuration.model or configuration.deployment_name or "default"
+    concurrency = configuration.concurrent_requests
+
+    # bypass the semaphore if concurrency is zero
+    if not concurrency:
+        log.info("no concurrency limiter for %s", limit_name)
+        return None
+
+    if limit_name not in _semaphores:
+        log.info("create concurrency limiter for %s: %s", limit_name, concurrency)
+        _semaphores[limit_name] = asyncio.Semaphore(concurrency)
+
+    return _semaphores[limit_name]
diff --git a/func-app/graphrag/index/llm/types.py b/func-app/graphrag/index/llm/types.py
new file mode 100644
index 0000000000..73c47737cb
--- /dev/null
+++ b/func-app/graphrag/index/llm/types.py
@@ -0,0 +1,10 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the 'LLMType' model."""
+
+from collections.abc import Callable
+from typing import TypeAlias
+
+TextSplitter: TypeAlias = Callable[[str], list[str]]
+TextListSplitter: TypeAlias = Callable[[list[str]], list[str]]
diff --git a/func-app/graphrag/index/load_pipeline_config.py b/func-app/graphrag/index/load_pipeline_config.py
new file mode 100644
index 0000000000..7488c8c4fe
--- /dev/null
+++ b/func-app/graphrag/index/load_pipeline_config.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing read_dotenv, load_pipeline_config, _parse_yaml and _create_include_constructor methods definition."""
+
+import json
+import logging
+from pathlib import Path
+
+import yaml
+from pyaml_env import parse_config as parse_config_with_env
+
+from graphrag.config import create_graphrag_config, read_dotenv
+from graphrag.index.config import PipelineConfig
+
+from .create_pipeline_config import create_pipeline_config
+
+log = logging.getLogger(__name__)
+
+
+def load_pipeline_config(config_or_path: str | PipelineConfig) -> PipelineConfig:
+    """Load a pipeline config from a file path or a config object."""
+    if isinstance(config_or_path, PipelineConfig):
+        log.info("using provided PipelineConfig instance")
+        config = config_or_path
+    elif config_or_path == "default":
+        config = create_pipeline_config(create_graphrag_config(root_dir="."))
+    else:
+        # Is there a .env file in the same directory as the config?
+ read_dotenv(str(Path(config_or_path).parent)) + + if config_or_path.endswith(".json"): + with Path(config_or_path).open("rb") as f: + config = json.loads(f.read().decode(encoding="utf-8", errors="strict")) + elif config_or_path.endswith((".yml", ".yaml")): + config = _parse_yaml(config_or_path) + else: + msg = f"Invalid config file type: {config_or_path}" + raise ValueError(msg) + + config = PipelineConfig.model_validate(config) + if not config.root_dir: + config.root_dir = str(Path(config_or_path).parent.resolve()) + + if config.extends is not None: + if isinstance(config.extends, str): + config.extends = [config.extends] + for extended_config in config.extends: + extended_config = load_pipeline_config(extended_config) + merged_config = { + **json.loads(extended_config.model_dump_json()), + **json.loads(config.model_dump_json(exclude_unset=True)), + } + config = PipelineConfig.model_validate(merged_config) + + return config + + +def _parse_yaml(path: str): + """Parse a yaml file, with support for !include directives.""" + # I don't like that this is static + loader_class = yaml.SafeLoader + + # Add !include constructor if not already present. + if "!include" not in loader_class.yaml_constructors: + loader_class.add_constructor("!include", _create_include_constructor()) + + return parse_config_with_env(path, loader=loader_class, default_value="") + + +def _create_include_constructor(): + """Create a constructor for !include directives.""" + + def handle_include(loader: yaml.Loader, node: yaml.Node): + """Include file referenced at node.""" + filename = str(Path(loader.name).parent / node.value) + if filename.endswith((".yml", ".yaml")): + return _parse_yaml(filename) + + with Path(filename).open("rb") as f: + return f.read().decode(encoding="utf-8", errors="strict") + + return handle_include diff --git a/func-app/graphrag/index/py.typed b/func-app/graphrag/index/py.typed new file mode 100644 index 0000000000..f4bd298955 --- /dev/null +++ b/func-app/graphrag/index/py.typed @@ -0,0 +1,2 @@ +# This package supports type hinting, +# see https://www.python.org/dev/peps/pep-0561/#packaging-type-information \ No newline at end of file diff --git a/func-app/graphrag/index/reporting/__init__.py b/func-app/graphrag/index/reporting/__init__.py new file mode 100644 index 0000000000..697d4fc51f --- /dev/null +++ b/func-app/graphrag/index/reporting/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Reporting utilities and implementations for the indexing engine.""" + +from .blob_workflow_callbacks import BlobWorkflowCallbacks +from .console_workflow_callbacks import ConsoleWorkflowCallbacks +from .file_workflow_callbacks import FileWorkflowCallbacks +from .load_pipeline_reporter import load_pipeline_reporter +from .progress_workflow_callbacks import ProgressWorkflowCallbacks + +__all__ = [ + "BlobWorkflowCallbacks", + "ConsoleWorkflowCallbacks", + "FileWorkflowCallbacks", + "ProgressWorkflowCallbacks", + "load_pipeline_reporter", +] diff --git a/func-app/graphrag/index/reporting/blob_workflow_callbacks.py b/func-app/graphrag/index/reporting/blob_workflow_callbacks.py new file mode 100644 index 0000000000..28f0b6d991 --- /dev/null +++ b/func-app/graphrag/index/reporting/blob_workflow_callbacks.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A reporter that writes to a blob storage.""" + +import json +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient +from datashaper import NoopWorkflowCallbacks + + +class BlobWorkflowCallbacks(NoopWorkflowCallbacks): + """A reporter that writes to a blob storage.""" + + _blob_service_client: BlobServiceClient + _container_name: str + _max_block_count: int = 25000 # 25k blocks per blob + + def __init__( + self, + connection_string: str | None, + container_name: str, + blob_name: str = "", + base_dir: str | None = None, + storage_account_blob_url: str | None = None, + ): # type: ignore + """Create a new instance of the BlobStorageReporter class.""" + if container_name is None: + msg = "No container name provided for blob storage." + raise ValueError(msg) + if connection_string is None and storage_account_blob_url is None: + msg = "No storage account blob url provided for blob storage." + raise ValueError(msg) + self._connection_string = connection_string + self._storage_account_blob_url = storage_account_blob_url + if self._connection_string: + self._blob_service_client = BlobServiceClient.from_connection_string( + self._connection_string + ) + else: + if storage_account_blob_url is None: + msg = "Either connection_string or storage_account_blob_url must be provided." + raise ValueError(msg) + + self._blob_service_client = BlobServiceClient( + storage_account_blob_url, + credential=DefaultAzureCredential(), + ) + + if blob_name == "": + blob_name = f"report/{datetime.now(tz=timezone.utc).strftime('%Y-%m-%d-%H:%M:%S:%f')}.logs.json" + + self._blob_name = str(Path(base_dir or "") / blob_name) + self._container_name = container_name + self._blob_client = self._blob_service_client.get_blob_client( + self._container_name, self._blob_name + ) + if not self._blob_client.exists(): + self._blob_client.create_append_blob() + + self._num_blocks = 0 # refresh block counter + + def _write_log(self, log: dict[str, Any]): + # create a new file when block count hits close 25k + if ( + self._num_blocks >= self._max_block_count + ): # Check if block count exceeds 25k + self.__init__( + self._connection_string, + self._container_name, + storage_account_blob_url=self._storage_account_blob_url, + ) + + blob_client = self._blob_service_client.get_blob_client( + self._container_name, self._blob_name + ) + blob_client.append_block(json.dumps(log) + "\n") + + # update the blob's block count + self._num_blocks += 1 + + def on_error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ): + """Report an error.""" + self._write_log({ + "type": "error", + "data": message, + "cause": str(cause), + "stack": stack, + "details": details, + }) + + def on_warning(self, message: str, details: dict | None = None): + """Report a warning.""" + self._write_log({"type": "warning", "data": message, "details": details}) + + def on_log(self, message: str, details: dict | None = None): + """Report a generic log message.""" + self._write_log({"type": "log", "data": message, "details": details}) diff --git a/func-app/graphrag/index/reporting/console_workflow_callbacks.py b/func-app/graphrag/index/reporting/console_workflow_callbacks.py new file mode 100644 index 0000000000..b1ab1278f7 --- /dev/null +++ b/func-app/graphrag/index/reporting/console_workflow_callbacks.py @@ -0,0 +1,32 @@ +# 
Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Console-based reporter for the workflow engine.""" + +from datashaper import NoopWorkflowCallbacks + + +class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks): + """A reporter that writes to a console.""" + + def on_error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ): + """Handle when an error occurs.""" + print(message, str(cause), stack, details) # noqa T201 + + def on_warning(self, message: str, details: dict | None = None): + """Handle when a warning occurs.""" + _print_warning(message) + + def on_log(self, message: str, details: dict | None = None): + """Handle when a log message is produced.""" + print(message, details) # noqa T201 + + +def _print_warning(skk): + print("\033[93m {}\033[00m".format(skk)) # noqa T201 diff --git a/func-app/graphrag/index/reporting/file_workflow_callbacks.py b/func-app/graphrag/index/reporting/file_workflow_callbacks.py new file mode 100644 index 0000000000..e659c4f644 --- /dev/null +++ b/func-app/graphrag/index/reporting/file_workflow_callbacks.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A reporter that writes to a file.""" + +import json +import logging +from io import TextIOWrapper +from pathlib import Path + +from datashaper import NoopWorkflowCallbacks + +log = logging.getLogger(__name__) + + +class FileWorkflowCallbacks(NoopWorkflowCallbacks): + """A reporter that writes to a file.""" + + _out_stream: TextIOWrapper + + def __init__(self, directory: str): + """Create a new file-based workflow reporter.""" + Path(directory).mkdir(parents=True, exist_ok=True) + self._out_stream = open( # noqa: PTH123, SIM115 + Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict" + ) + + def on_error( + self, + message: str, + cause: BaseException | None = None, + stack: str | None = None, + details: dict | None = None, + ): + """Handle when an error occurs.""" + self._out_stream.write( + json.dumps({ + "type": "error", + "data": message, + "stack": stack, + "source": str(cause), + "details": details, + }) + + "\n" + ) + message = f"{message} details={details}" + log.info(message) + + def on_warning(self, message: str, details: dict | None = None): + """Handle when a warning occurs.""" + self._out_stream.write( + json.dumps({"type": "warning", "data": message, "details": details}) + "\n" + ) + _print_warning(message) + + def on_log(self, message: str, details: dict | None = None): + """Handle when a log message is produced.""" + self._out_stream.write( + json.dumps({"type": "log", "data": message, "details": details}) + "\n" + ) + + message = f"{message} details={details}" + log.info(message) + + +def _print_warning(skk): + log.warning(skk) diff --git a/func-app/graphrag/index/reporting/load_pipeline_reporter.py b/func-app/graphrag/index/reporting/load_pipeline_reporter.py new file mode 100644 index 0000000000..0386ea03d1 --- /dev/null +++ b/func-app/graphrag/index/reporting/load_pipeline_reporter.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Load pipeline reporter method.""" + +from pathlib import Path +from typing import cast + +from datashaper import WorkflowCallbacks + +from graphrag.config import ReportingType +from graphrag.index.config import ( + PipelineBlobReportingConfig, + PipelineFileReportingConfig, + PipelineReportingConfig, +) + +from .blob_workflow_callbacks import BlobWorkflowCallbacks +from .console_workflow_callbacks import ConsoleWorkflowCallbacks +from .file_workflow_callbacks import FileWorkflowCallbacks + + +def load_pipeline_reporter( + config: PipelineReportingConfig | None, root_dir: str | None +) -> WorkflowCallbacks: + """Create a reporter for the given pipeline config.""" + config = config or PipelineFileReportingConfig(base_dir="reports") + + match config.type: + case ReportingType.file: + config = cast(PipelineFileReportingConfig, config) + return FileWorkflowCallbacks( + str(Path(root_dir or "") / (config.base_dir or "")) + ) + case ReportingType.console: + return ConsoleWorkflowCallbacks() + case ReportingType.blob: + config = cast(PipelineBlobReportingConfig, config) + return BlobWorkflowCallbacks( + config.connection_string, + config.container_name, + base_dir=config.base_dir, + storage_account_blob_url=config.storage_account_blob_url, + ) + case _: + msg = f"Unknown reporting type: {config.type}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/reporting/progress_workflow_callbacks.py b/func-app/graphrag/index/reporting/progress_workflow_callbacks.py new file mode 100644 index 0000000000..68e407e223 --- /dev/null +++ b/func-app/graphrag/index/reporting/progress_workflow_callbacks.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A workflow callback manager that emits updates to a ProgressReporter.""" + +from typing import Any + +from datashaper import ExecutionNode, NoopWorkflowCallbacks, Progress, TableContainer + +from graphrag.common.progress import ProgressReporter + + +class ProgressWorkflowCallbacks(NoopWorkflowCallbacks): + """A callbackmanager that delegates to a ProgressReporter.""" + + _root_progress: ProgressReporter + _progress_stack: list[ProgressReporter] + + def __init__(self, progress: ProgressReporter) -> None: + """Create a new ProgressWorkflowCallbacks.""" + self._progress = progress + self._progress_stack = [progress] + + def _pop(self) -> None: + self._progress_stack.pop() + + def _push(self, name: str) -> None: + self._progress_stack.append(self._latest.child(name)) + + @property + def _latest(self) -> ProgressReporter: + return self._progress_stack[-1] + + def on_workflow_start(self, name: str, instance: object) -> None: + """Execute this callback when a workflow starts.""" + self._push(name) + + def on_workflow_end(self, name: str, instance: object) -> None: + """Execute this callback when a workflow ends.""" + self._pop() + + def on_step_start(self, node: ExecutionNode, inputs: dict[str, Any]) -> None: + """Execute this callback every time a step starts.""" + verb_id_str = f" ({node.node_id})" if node.has_explicit_id else "" + self._push(f"Verb {node.verb.name}{verb_id_str}") + self._latest(Progress(percent=0)) + + def on_step_end(self, node: ExecutionNode, result: TableContainer | None) -> None: + """Execute this callback every time a step ends.""" + self._pop() + + def on_step_progress(self, node: ExecutionNode, progress: Progress) -> None: + """Handle when progress occurs.""" + self._latest(progress) diff --git a/func-app/graphrag/index/run.py 
b/func-app/graphrag/index/run.py new file mode 100644 index 0000000000..27416f7d7f --- /dev/null +++ b/func-app/graphrag/index/run.py @@ -0,0 +1,471 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Different methods to run the pipeline.""" + +import gc +import json +import logging +import time +import traceback +from collections.abc import AsyncIterable +from dataclasses import asdict +from io import BytesIO +from pathlib import Path +from string import Template +from typing import cast + +import pandas as pd +from datashaper import ( + DEFAULT_INPUT_NAME, + MemoryProfile, + Workflow, + WorkflowCallbacks, + WorkflowCallbacksManager, + WorkflowRunResult, +) +from graphrag.config.models.graphdb_config import GraphDBConfig + +from .cache import InMemoryCache, PipelineCache, load_cache +from .config import ( + PipelineBlobCacheConfig, + PipelineBlobReportingConfig, + PipelineBlobStorageConfig, + PipelineCacheConfigTypes, + PipelineConfig, + PipelineFileCacheConfig, + PipelineFileReportingConfig, + PipelineFileStorageConfig, + PipelineInputConfigTypes, + PipelineMemoryCacheConfig, + PipelineReportingConfigTypes, + PipelineStorageConfigTypes, + PipelineWorkflowReference, + PipelineWorkflowStep, +) +from .context import PipelineRunContext, PipelineRunStats +from .emit import TableEmitterType, create_table_emitters +from .input import load_input +from .load_pipeline_config import load_pipeline_config +from graphrag.common.progress import NullProgressReporter, ProgressReporter +from .reporting import ( + ConsoleWorkflowCallbacks, + ProgressWorkflowCallbacks, + load_pipeline_reporter, +) +from graphrag.common.storage import MemoryPipelineStorage, PipelineStorage, load_storage +from .typing import PipelineRunResult + +# Register all verbs +from .verbs import * # noqa +from .workflows import ( + VerbDefinitions, + WorkflowDefinitions, + create_workflow, + load_workflows, +) + +log = logging.getLogger(__name__) + + +async def run_pipeline_with_config( + config_or_path: PipelineConfig | str, + workflows: list[PipelineWorkflowReference] | None = None, + dataset: pd.DataFrame | None = None, + storage: PipelineStorage | None = None, + cache: PipelineCache | None = None, + callbacks: WorkflowCallbacks | None = None, + progress_reporter: ProgressReporter | None = None, + input_post_process_steps: list[PipelineWorkflowStep] | None = None, + additional_verbs: VerbDefinitions | None = None, + additional_workflows: WorkflowDefinitions | None = None, + emit: list[TableEmitterType] | None = None, + memory_profile: bool = False, + run_id: str | None = None, + is_resume_run: bool = False, + context_id: str | None = None, + **_kwargs: dict, +) -> AsyncIterable[PipelineRunResult]: + """Run a pipeline with the given config. + + Args: + - config_or_path - The config to run the pipeline with + - workflows - The workflows to run (this overrides the config) + - dataset - The dataset to run the pipeline on (this overrides the config) + - storage - The storage to use for the pipeline (this overrides the config) + - cache - The cache to use for the pipeline (this overrides the config) + - reporter - The reporter to use for the pipeline (this overrides the config) + - input_post_process_steps - The post process steps to run on the input data (this overrides the config) + - additional_verbs - The custom verbs to use for the pipeline. + - additional_workflows - The custom workflows to use for the pipeline. + - emit - The table emitters to use for the pipeline. 
+ - memory_profile - Whether or not to profile the memory. + - run_id - The run id to start or resume from. + """ + if isinstance(config_or_path, str): + log.info("Running pipeline with config %s", config_or_path) + else: + log.info("Running pipeline") + + run_id = run_id or time.strftime("%Y%m%d-%H%M%S") + config = load_pipeline_config(config_or_path) + config = _apply_substitutions(config, run_id) + root_dir = config.root_dir + + def _create_storage(config: PipelineStorageConfigTypes | None) -> PipelineStorage: + return load_storage( + config + or PipelineFileStorageConfig(base_dir=str(Path(root_dir or "") / "output")) + ) + + def _create_cache(config: PipelineCacheConfigTypes | None) -> PipelineCache: + return load_cache(config or PipelineMemoryCacheConfig(), root_dir=root_dir) + + def _create_reporter( + config: PipelineReportingConfigTypes | None, + ) -> WorkflowCallbacks | None: + return load_pipeline_reporter(config, root_dir) if config else None + + async def _create_input( + config: PipelineInputConfigTypes | None, + ) -> pd.DataFrame | None: + if config is None: + return None + + return await load_input(config, progress_reporter, root_dir) + + def _create_postprocess_steps( + config: PipelineInputConfigTypes | None, + ) -> list[PipelineWorkflowStep] | None: + return config.post_process if config is not None else None + + progress_reporter = progress_reporter or NullProgressReporter() + storage = storage or _create_storage(config.storage) + cache = cache or _create_cache(config.cache) + callbacks = callbacks or _create_reporter(config.reporting) + dataset = dataset if dataset is not None else await _create_input(config.input) + post_process_steps = input_post_process_steps or _create_postprocess_steps( + config.input + ) + workflows = workflows or config.workflows + + if dataset is None: + msg = "No dataset provided!" + raise ValueError(msg) + + async for table in run_pipeline( + workflows=workflows, + dataset=dataset, + storage=storage, + cache=cache, + callbacks=callbacks, + input_post_process_steps=post_process_steps, + memory_profile=memory_profile, + additional_verbs=additional_verbs, + additional_workflows=additional_workflows, + progress_reporter=progress_reporter, + emit=emit, + is_resume_run=is_resume_run, + graphdb_params = config.graphdb_params, + context_id=context_id, + ): + yield table + + +async def run_pipeline( + workflows: list[PipelineWorkflowReference], + dataset: pd.DataFrame, + storage: PipelineStorage | None = None, + cache: PipelineCache | None = None, + callbacks: WorkflowCallbacks | None = None, + progress_reporter: ProgressReporter | None = None, + input_post_process_steps: list[PipelineWorkflowStep] | None = None, + additional_verbs: VerbDefinitions | None = None, + additional_workflows: WorkflowDefinitions | None = None, + emit: list[TableEmitterType] | None = None, + memory_profile: bool = False, + is_resume_run: bool = False, + graphdb_params: GraphDBConfig|None = None, + context_id: str | None = None, + **_kwargs: dict, +) -> AsyncIterable[PipelineRunResult]: + """Run the pipeline. + + Args: + - workflows - The workflows to run + - dataset - The dataset to run the pipeline on, specifically a dataframe with the following columns at a minimum: + - id - The id of the document + - text - The text of the document + - title - The title of the document + These must exist after any post process steps are run if there are any! 
+ - storage - The storage to use for the pipeline + - cache - The cache to use for the pipeline + - reporter - The reporter to use for the pipeline + - input_post_process_steps - The post process steps to run on the input data + - additional_verbs - The custom verbs to use for the pipeline + - additional_workflows - The custom workflows to use for the pipeline + - debug - Whether or not to run in debug mode + Returns: + - output - An iterable of workflow results as they complete running, as well as any errors that occur + """ + start_time = time.time() + stats = PipelineRunStats() + storage = storage or MemoryPipelineStorage() + cache = cache or InMemoryCache() + progress_reporter = progress_reporter or NullProgressReporter() + callbacks = callbacks or ConsoleWorkflowCallbacks() + callbacks = _create_callback_chain(callbacks, progress_reporter) + emit = emit or [TableEmitterType.Parquet] + emitters = create_table_emitters( + emit, + storage, + lambda e, s, d: cast(WorkflowCallbacks, callbacks).on_error( + "Error emitting table", e, s, d + ), + graphdb_params, + context_id, + ) + loaded_workflows = load_workflows( + workflows, + additional_verbs=additional_verbs, + additional_workflows=additional_workflows, + memory_profile=memory_profile, + ) + workflows_to_run = loaded_workflows.workflows + workflow_dependencies = loaded_workflows.dependencies + + context = _create_run_context(storage, cache, stats) + + if len(emitters) == 0: + log.info( + "No emitters provided. No table outputs will be generated. This is probably not correct." + ) + + async def dump_stats() -> None: + await storage.set("stats.json", json.dumps(asdict(stats), indent=4)) + + async def load_table_from_storage(name: str) -> pd.DataFrame: + if not await storage.has(name): + msg = f"Could not find {name} in storage!" 
+ raise ValueError(msg) + try: + log.info("read table from storage: %s", name) + return pd.read_parquet(BytesIO(await storage.get(name, as_bytes=True))) + except Exception: + log.exception("error loading table from storage: %s", name) + raise + + async def inject_workflow_data_dependencies(workflow: Workflow) -> None: + workflow.add_table(DEFAULT_INPUT_NAME, dataset) + deps = workflow_dependencies[workflow.name] + log.info("dependencies for %s: %s", workflow.name, deps) + for id in deps: + workflow_id = f"workflow:{id}" + table = await load_table_from_storage(f"{id}.parquet") + workflow.add_table(workflow_id, table) + + async def write_workflow_stats( + workflow: Workflow, + workflow_result: WorkflowRunResult, + workflow_start_time: float, + ) -> None: + for vt in workflow_result.verb_timings: + stats.workflows[workflow.name][f"{vt.index}_{vt.verb}"] = vt.timing + + workflow_end_time = time.time() + stats.workflows[workflow.name]["overall"] = ( + workflow_end_time - workflow_start_time + ) + stats.total_runtime = time.time() - start_time + await dump_stats() + + if workflow_result.memory_profile is not None: + await _save_profiler_stats( + storage, workflow.name, workflow_result.memory_profile + ) + + log.debug( + "first row of %s => %s", workflow_name, workflow.output().iloc[0].to_json() + ) + + async def emit_workflow_output(workflow: Workflow) -> pd.DataFrame: + output = cast(pd.DataFrame, workflow.output()) + for emitter in emitters: + await emitter.emit(workflow.name, output) + return output + + dataset = await _run_post_process_steps( + input_post_process_steps, dataset, context, callbacks + ) + + # Make sure the incoming data is valid + _validate_dataset(dataset) + + log.info("Final # of rows loaded: %s", len(dataset)) + stats.num_documents = len(dataset) + last_workflow = "input" + + try: + await dump_stats() + + for workflow_to_run in workflows_to_run: + # Try to flush out any intermediate dataframes + gc.collect() + + workflow = workflow_to_run.workflow + workflow_name: str = workflow.name + last_workflow = workflow_name + + log.info("Running workflow: %s...", workflow_name) + + if is_resume_run and await storage.has( + f"{workflow_to_run.workflow.name}.parquet" + ): + log.info("Skipping %s because it already exists", workflow_name) + continue + + stats.workflows[workflow_name] = {"overall": 0.0} + await inject_workflow_data_dependencies(workflow) + + workflow_start_time = time.time() + result = await workflow.run(context, callbacks) + await write_workflow_stats(workflow, result, workflow_start_time) + + # Save the output from the workflow + output = await emit_workflow_output(workflow) + yield PipelineRunResult(workflow_name, output, None) + output = None + workflow.dispose() + workflow = None + + stats.total_runtime = time.time() - start_time + await dump_stats() + except Exception as e: + log.exception("error running workflow %s", last_workflow) + cast(WorkflowCallbacks, callbacks).on_error( + "Error running pipeline!", e, traceback.format_exc() + ) + yield PipelineRunResult(last_workflow, None, [e]) + + +def _create_callback_chain( + callbacks: WorkflowCallbacks | None, progress: ProgressReporter | None +) -> WorkflowCallbacks: + """Create a callbacks manager.""" + manager = WorkflowCallbacksManager() + if callbacks is not None: + manager.register(callbacks) + if progress is not None: + manager.register(ProgressWorkflowCallbacks(progress)) + return manager + + +async def _save_profiler_stats( + storage: PipelineStorage, workflow_name: str, profile: MemoryProfile +): + """Save 
the profiler stats to the storage."""
+    await storage.set(
+        f"{workflow_name}_profiling.peak_stats.csv",
+        profile.peak_stats.to_csv(index=True),
+    )
+
+    await storage.set(
+        f"{workflow_name}_profiling.snapshot_stats.csv",
+        profile.snapshot_stats.to_csv(index=True),
+    )
+
+    await storage.set(
+        f"{workflow_name}_profiling.time_stats.csv",
+        profile.time_stats.to_csv(index=True),
+    )
+
+    await storage.set(
+        f"{workflow_name}_profiling.detailed_view.csv",
+        profile.detailed_view.to_csv(index=True),
+    )
+
+
+async def _run_post_process_steps(
+    post_process: list[PipelineWorkflowStep] | None,
+    dataset: pd.DataFrame,
+    context: PipelineRunContext,
+    callbacks: WorkflowCallbacks,
+) -> pd.DataFrame:
+    """Run the post process steps on the input dataset.
+
+    Args:
+        - post_process - The post process steps to run
+        - dataset - The dataset to run the steps on
+        - context - The pipeline run context
+    Returns:
+        - output - The dataset after running the post process steps
+    """
+    if post_process is not None and len(post_process) > 0:
+        input_workflow = create_workflow(
+            "Input Post Process",
+            post_process,
+        )
+        input_workflow.add_table(DEFAULT_INPUT_NAME, dataset)
+        await input_workflow.run(
+            context=context,
+            callbacks=callbacks,
+        )
+        dataset = cast(pd.DataFrame, input_workflow.output())
+    return dataset
+
+
+def _validate_dataset(dataset: pd.DataFrame):
+    """Validate the dataset for the pipeline.
+
+    Args:
+        - dataset - The dataset to validate
+    """
+    if not isinstance(dataset, pd.DataFrame):
+        msg = "Dataset must be a pandas dataframe!"
+        raise TypeError(msg)
+
+
+def _apply_substitutions(config: PipelineConfig, run_id: str) -> PipelineConfig:
+    substitutions = {"timestamp": run_id}
+
+    if (
+        isinstance(
+            config.storage, PipelineFileStorageConfig | PipelineBlobStorageConfig
+        )
+        and config.storage.base_dir
+    ):
+        config.storage.base_dir = Template(config.storage.base_dir).substitute(
+            substitutions
+        )
+    if (
+        isinstance(config.cache, PipelineFileCacheConfig | PipelineBlobCacheConfig)
+        and config.cache.base_dir
+    ):
+        config.cache.base_dir = Template(config.cache.base_dir).substitute(
+            substitutions
+        )
+
+    if (
+        isinstance(
+            config.reporting, PipelineFileReportingConfig | PipelineBlobReportingConfig
+        )
+        and config.reporting.base_dir
+    ):
+        config.reporting.base_dir = Template(config.reporting.base_dir).substitute(
+            substitutions
+        )
+
+    return config
+
+
+def _create_run_context(
+    storage: PipelineStorage,
+    cache: PipelineCache,
+    stats: PipelineRunStats,
+) -> PipelineRunContext:
+    """Create the run context for the pipeline."""
+    return PipelineRunContext(
+        stats=stats,
+        cache=cache,
+        storage=storage,
+    )
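As a consumer-side sketch, the async generator above can be driven as follows. The pipeline config path and run id are hypothetical placeholders; only run_pipeline_with_config and its result objects come from this module:

    import asyncio

    from graphrag.index.run import run_pipeline_with_config

    async def main() -> None:
        # Stream results workflow by workflow; failures arrive on the result
        # object's errors list rather than being raised.
        async for result in run_pipeline_with_config(
            "./pipeline.yml",  # hypothetical pipeline config path
            run_id="example-run",
        ):
            if result.errors:
                print(f"{result.workflow} failed: {result.errors}")
            else:
                print(f"{result.workflow} finished")

    asyncio.run(main())
diff --git a/func-app/graphrag/index/text_splitting/__init__.py b/func-app/graphrag/index/text_splitting/__init__.py
new file mode 100644
index 0000000000..4653adb22b
--- /dev/null
+++ b/func-app/graphrag/index/text_splitting/__init__.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2024 Microsoft Corporation.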
+# Licensed under the MIT License + +"""The Indexing Engine Text Splitting package root.""" + +from .check_token_limit import check_token_limit +from .text_splitting import ( + DecodeFn, + EncodedText, + EncodeFn, + LengthFn, + NoopTextSplitter, + TextListSplitter, + TextListSplitterType, + TextSplitter, + Tokenizer, + TokenTextSplitter, + split_text_on_tokens, +) + +__all__ = [ + "DecodeFn", + "EncodeFn", + "EncodedText", + "LengthFn", + "NoopTextSplitter", + "TextListSplitter", + "TextListSplitterType", + "TextSplitter", + "TokenTextSplitter", + "Tokenizer", + "check_token_limit", + "split_text_on_tokens", +] diff --git a/func-app/graphrag/index/text_splitting/check_token_limit.py b/func-app/graphrag/index/text_splitting/check_token_limit.py new file mode 100644 index 0000000000..1a5f862254 --- /dev/null +++ b/func-app/graphrag/index/text_splitting/check_token_limit.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Token limit method definition.""" + +from .text_splitting import TokenTextSplitter + + +def check_token_limit(text, max_token): + """Check token limit.""" + text_splitter = TokenTextSplitter(chunk_size=max_token, chunk_overlap=0) + docs = text_splitter.split_text(text) + if len(docs) > 1: + return 0 + return 1 diff --git a/func-app/graphrag/index/text_splitting/text_splitting.py b/func-app/graphrag/index/text_splitting/text_splitting.py new file mode 100644 index 0000000000..0badc8977c --- /dev/null +++ b/func-app/graphrag/index/text_splitting/text_splitting.py @@ -0,0 +1,244 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing the 'Tokenizer', 'TextSplitter', 'NoopTextSplitter' and 'TokenTextSplitter' models.""" + +import json +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable, Collection, Iterable +from dataclasses import dataclass +from enum import Enum +from typing import Any, Literal, cast + +import pandas as pd +import tiktoken + +from graphrag.index.utils import num_tokens_from_string + +EncodedText = list[int] +DecodeFn = Callable[[EncodedText], str] +EncodeFn = Callable[[str], EncodedText] +LengthFn = Callable[[str], int] + +log = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class Tokenizer: + """Tokenizer data class.""" + + chunk_overlap: int + """Overlap in tokens between chunks""" + tokens_per_chunk: int + """Maximum number of tokens per chunk""" + decode: DecodeFn + """ Function to decode a list of token ids to a string""" + encode: EncodeFn + """ Function to encode a string to a list of token ids""" + + +class TextSplitter(ABC): + """Text splitter class definition.""" + + _chunk_size: int + _chunk_overlap: int + _length_function: LengthFn + _keep_separator: bool + _add_start_index: bool + _strip_whitespace: bool + + def __init__( + self, + # based on text-ada-002-embedding max input buffer length + # https://platform.openai.com/docs/guides/embeddings/second-generation-models + chunk_size: int = 8191, + chunk_overlap: int = 100, + length_function: LengthFn = len, + keep_separator: bool = False, + add_start_index: bool = False, + strip_whitespace: bool = True, + ): + """Init method definition.""" + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._length_function = length_function + self._keep_separator = keep_separator + self._add_start_index = add_start_index + self._strip_whitespace = strip_whitespace + + @abstractmethod + def split_text(self, text: str | list[str]) -> 
Iterable[str]: + """Split text method definition.""" + + +class NoopTextSplitter(TextSplitter): + """Noop text splitter class definition.""" + + def split_text(self, text: str | list[str]) -> Iterable[str]: + """Split text method definition.""" + return [text] if isinstance(text, str) else text + + +class TokenTextSplitter(TextSplitter): + """Token text splitter class definition.""" + + _allowed_special: Literal["all"] | set[str] + _disallowed_special: Literal["all"] | Collection[str] + + def __init__( + self, + encoding_name: str = "cl100k_base", + model_name: str | None = None, + allowed_special: Literal["all"] | set[str] | None = None, + disallowed_special: Literal["all"] | Collection[str] = "all", + **kwargs: Any, + ): + """Init method definition.""" + super().__init__(**kwargs) + if model_name is not None: + try: + enc = tiktoken.encoding_for_model(model_name) + except KeyError: + log.exception("Model %s not found, using %s", model_name, encoding_name) + enc = tiktoken.get_encoding(encoding_name) + else: + enc = tiktoken.get_encoding(encoding_name) + self._tokenizer = enc + self._allowed_special = allowed_special or set() + self._disallowed_special = disallowed_special + + def encode(self, text: str) -> list[int]: + """Encode the given text into an int-vector.""" + return self._tokenizer.encode( + text, + allowed_special=self._allowed_special, + disallowed_special=self._disallowed_special, + ) + + def num_tokens(self, text: str) -> int: + """Return the number of tokens in a string.""" + return len(self.encode(text)) + + def split_text(self, text: str | list[str]) -> list[str]: + """Split text method.""" + if cast(bool, pd.isna(text)) or text == "": + return [] + if isinstance(text, list): + text = " ".join(text) + if not isinstance(text, str): + msg = f"Attempting to split a non-string value, actual is {type(text)}" + raise TypeError(msg) + + tokenizer = Tokenizer( + chunk_overlap=self._chunk_overlap, + tokens_per_chunk=self._chunk_size, + decode=self._tokenizer.decode, + encode=lambda text: self.encode(text), + ) + + return split_text_on_tokens(text=text, tokenizer=tokenizer) + + +class TextListSplitterType(str, Enum): + """Enum for the type of the TextListSplitter.""" + + DELIMITED_STRING = "delimited_string" + JSON = "json" + + +class TextListSplitter(TextSplitter): + """Text list splitter class definition.""" + + def __init__( + self, + chunk_size: int, + splitter_type: TextListSplitterType = TextListSplitterType.JSON, + input_delimiter: str | None = None, + output_delimiter: str | None = None, + model_name: str | None = None, + encoding_name: str | None = None, + ): + """Initialize the TextListSplitter with a chunk size.""" + # Set the chunk overlap to 0 as we use full strings + super().__init__(chunk_size, chunk_overlap=0) + self._type = splitter_type + self._input_delimiter = input_delimiter + self._output_delimiter = output_delimiter or "\n" + self._length_function = lambda x: num_tokens_from_string( + x, model=model_name, encoding_name=encoding_name + ) + + def split_text(self, text: str | list[str]) -> Iterable[str]: + """Split a string list into a list of strings for a given chunk size.""" + if not text: + return [] + + result: list[str] = [] + current_chunk: list[str] = [] + + # Add the brackets + current_length: int = self._length_function("[]") + + # Input should be a string list joined by a delimiter + string_list = self._load_text_list(text) + + if len(string_list) == 1: + return string_list + + for item in string_list: + # Count the length of the item and add comma + 
            item_length = self._length_function(f"{item},")
+
+            if current_length + item_length > self._chunk_size:
+                if current_chunk:
+                    # Flush the current chunk to the result
+                    self._append_to_result(result, current_chunk)
+
+                # Start a new chunk with this item; item_length already
+                # accounts for the trailing comma
+                current_chunk = [item]
+                current_length = item_length
+            else:
+                # Add the item to the current chunk; the comma is already
+                # counted in item_length
+                current_chunk.append(item)
+                current_length += item_length
+
+        # Add the last chunk to the result
+        self._append_to_result(result, current_chunk)
+
+        return result
+
+    def _load_text_list(self, text: str | list[str]):
+        """Load the text list based on the type."""
+        if isinstance(text, list):
+            string_list = text
+        elif self._type == TextListSplitterType.JSON:
+            string_list = json.loads(text)
+        else:
+            string_list = text.split(self._input_delimiter)
+        return string_list
+
+    def _append_to_result(self, chunk_list: list[str], new_chunk: list[str]):
+        """Append the current chunk to the result."""
+        if new_chunk:
+            if self._type == TextListSplitterType.JSON:
+                chunk_list.append(json.dumps(new_chunk))
+            else:
+                chunk_list.append(self._output_delimiter.join(new_chunk))
+
+
+def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
+    """Split incoming text and return chunks using tokenizer."""
+    splits: list[str] = []
+    input_ids = tokenizer.encode(text)
+    start_idx = 0
+    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+    chunk_ids = input_ids[start_idx:cur_idx]
+    while start_idx < len(input_ids):
+        splits.append(tokenizer.decode(chunk_ids))
+        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
+        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+        chunk_ids = input_ids[start_idx:cur_idx]
+    return splits
diff --git a/func-app/graphrag/index/typing.py b/func-app/graphrag/index/typing.py
new file mode 100644
index 0000000000..ed1d7e93e7
--- /dev/null
+++ b/func-app/graphrag/index/typing.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the 'PipelineRunResult' model."""
+
+from collections.abc import Callable
+from dataclasses import dataclass
+
+import pandas as pd
+
+ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
+
+
+@dataclass
+class PipelineRunResult:
+    """Pipeline run result class definition."""
+
+    workflow: str
+    result: pd.DataFrame | None
+    errors: list[BaseException] | None
diff --git a/func-app/graphrag/index/utils/__init__.py b/func-app/graphrag/index/utils/__init__.py
new file mode 100644
index 0000000000..7cbbb53d75
--- /dev/null
+++ b/func-app/graphrag/index/utils/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024 Microsoft Corporation.
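+# A minimal sketch of the sliding window produced by split_text_on_tokens
+# above: each step advances by tokens_per_chunk - chunk_overlap, so a 5-token
+# chunk with a 2-token overlap walks a 12-token input as [0:5], [3:8],
+# [6:11], [9:12]. Using the Tokenizer wiring shown in TokenTextSplitter, a
+# character-level toy illustrates this:
+#
+#     toy = Tokenizer(
+#         chunk_overlap=2,
+#         tokens_per_chunk=5,
+#         decode=lambda ids: "".join(chr(i) for i in ids),
+#         encode=lambda s: [ord(c) for c in s],
+#     )
+#     split_text_on_tokens(text="abcdefghijkl", tokenizer=toy)
+#     # -> ["abcde", "defgh", "ghijk", "jkl"]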
+# Licensed under the MIT License + +"""Utils methods definition.""" + +from .dicts import dict_has_keys_with_types +from .hashing import gen_md5_hash +from .is_null import is_null +from .load_graph import load_graph +from .string import clean_str +from .tokens import num_tokens_from_string, string_from_tokens +from .topological_sort import topological_sort +from .uuid import gen_uuid + +__all__ = [ + "clean_str", + "dict_has_keys_with_types", + "gen_md5_hash", + "gen_uuid", + "is_null", + "load_graph", + "num_tokens_from_string", + "string_from_tokens", + "topological_sort", +] diff --git a/func-app/graphrag/index/utils/dataframes.py b/func-app/graphrag/index/utils/dataframes.py new file mode 100644 index 0000000000..ea65d71d7a --- /dev/null +++ b/func-app/graphrag/index/utils/dataframes.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing DataFrame utilities.""" + +from collections.abc import Callable +from typing import Any, cast + +import pandas as pd +from pandas._typing import MergeHow + + +def drop_columns(df: pd.DataFrame, *column: str) -> pd.DataFrame: + """Drop columns from a dataframe.""" + return df.drop(list(column), axis=1) + + +def where_column_equals(df: pd.DataFrame, column: str, value: Any) -> pd.DataFrame: + """Return a filtered DataFrame where a column equals a value.""" + return cast(pd.DataFrame, df[df[column] == value]) + + +def antijoin(df: pd.DataFrame, exclude: pd.DataFrame, column: str) -> pd.DataFrame: + """Return an anti-joined dataframe. + + Arguments: + * df: The DataFrame to apply the exclusion to + * exclude: The DataFrame containing rows to remove. + * column: The join-on column. + """ + result = df.merge( + exclude[[column]], + on=column, + how="outer", + indicator=True, + ) + if "_merge" in result.columns: + result = result[result["_merge"] == "left_only"].drop("_merge", axis=1) + return cast(pd.DataFrame, result) + + +def transform_series(series: pd.Series, fn: Callable[[Any], Any]) -> pd.Series: + """Apply a transformation function to a series.""" + return cast(pd.Series, series.apply(fn)) + + +def join( + left: pd.DataFrame, right: pd.DataFrame, key: str, strategy: MergeHow = "left" +) -> pd.DataFrame: + """Perform a table join.""" + return left.merge(right, on=key, how=strategy) + + +def union(*frames: pd.DataFrame) -> pd.DataFrame: + """Perform a union operation on the given set of dataframes.""" + return pd.concat(list(frames)) + + +def select(df: pd.DataFrame, *columns: str) -> pd.DataFrame: + """Select columns from a dataframe.""" + return cast(pd.DataFrame, df[list(columns)]) diff --git a/func-app/graphrag/index/utils/dicts.py b/func-app/graphrag/index/utils/dicts.py new file mode 100644 index 0000000000..4d3662e0b8 --- /dev/null +++ b/func-app/graphrag/index/utils/dicts.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. 
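+# A minimal sketch of the antijoin helper above, with two hypothetical frames:
+#
+#     df = pd.DataFrame({"id": [1, 2, 3], "v": ["a", "b", "c"]})
+#     exclude = pd.DataFrame({"id": [2]})
+#     antijoin(df, exclude, "id")  # keeps only the rows with id 1 and 3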
+# Licensed under the MIT License
+
+"""A utility module containing methods for inspecting and verifying dictionary types."""
+
+
+def dict_has_keys_with_types(
+    data: dict, expected_fields: list[tuple[str, type]]
+) -> bool:
+    """Return True if the given dictionary has the given keys with the given types."""
+    for field, field_type in expected_fields:
+        if field not in data:
+            return False
+
+        value = data[field]
+        if not isinstance(value, field_type):
+            return False
+    return True
diff --git a/func-app/graphrag/index/utils/ds_util.py b/func-app/graphrag/index/utils/ds_util.py
new file mode 100644
index 0000000000..d65c69e4a8
--- /dev/null
+++ b/func-app/graphrag/index/utils/ds_util.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A utility module containing datashaper-specific utility methods."""
+
+from typing import cast
+
+from datashaper import TableContainer, VerbInput
+
+_NAMED_INPUTS_REQUIRED = "Named inputs are required"
+
+
+def get_required_input_table(input: VerbInput, name: str) -> TableContainer:
+    """Get a required input table by name."""
+    return cast(TableContainer, get_named_input_table(input, name, required=True))
+
+
+def get_named_input_table(
+    input: VerbInput, name: str, required: bool = False
+) -> TableContainer | None:
+    """Get an input table from datashaper verb-inputs by name."""
+    named_inputs = input.named
+    if named_inputs is None:
+        if not required:
+            return None
+        raise ValueError(_NAMED_INPUTS_REQUIRED)
+
+    result = named_inputs.get(name)
+    if result is None and required:
+        msg = f"input '{name}' is required"
+        raise ValueError(msg)
+    return result
diff --git a/func-app/graphrag/index/utils/hashing.py b/func-app/graphrag/index/utils/hashing.py
new file mode 100644
index 0000000000..342ae99d44
--- /dev/null
+++ b/func-app/graphrag/index/utils/hashing.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Hashing utilities."""
+
+from collections.abc import Iterable
+from hashlib import md5
+from typing import Any
+
+
+def gen_md5_hash(item: dict[str, Any], hashcode: Iterable[str]):
+    """Generate an md5 hash."""
+    hashed = "".join([str(item[column]) for column in hashcode])
+    return md5(hashed.encode('utf-8'), usedforsecurity=False).hexdigest()
diff --git a/func-app/graphrag/index/utils/is_null.py b/func-app/graphrag/index/utils/is_null.py
new file mode 100644
index 0000000000..f5df1955ae
--- /dev/null
+++ b/func-app/graphrag/index/utils/is_null.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Defines the is_null utility."""
+
+import math
+from typing import Any
+
+
+def is_null(value: Any) -> bool:
+    """Check if value is null or is nan."""
+
+    def is_none() -> bool:
+        return value is None
+
+    def is_nan() -> bool:
+        return isinstance(value, float) and math.isnan(value)
+
+    return is_none() or is_nan()
diff --git a/func-app/graphrag/index/utils/load_graph.py b/func-app/graphrag/index/utils/load_graph.py
new file mode 100644
index 0000000000..57992a04c8
--- /dev/null
+++ b/func-app/graphrag/index/utils/load_graph.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2024 Microsoft Corporation.
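+# Minimal sketches of the helpers above (hypothetical values):
+#
+#     gen_md5_hash({"title": "doc1", "text": "hi"}, ["title", "text"])
+#     # == md5(b"doc1hi").hexdigest(): stable across runs for identical rows
+#     is_null(float("nan"))  # True; is_null(0) and is_null("") are False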
+# Licensed under the MIT License + +"""Networkx load_graph utility definition.""" + +import networkx as nx + + +def load_graph(graphml: str | nx.Graph) -> nx.Graph: + """Load a graph from a graphml file or a networkx graph.""" + return nx.parse_graphml(graphml) if isinstance(graphml, str) else graphml diff --git a/func-app/graphrag/index/utils/rate_limiter.py b/func-app/graphrag/index/utils/rate_limiter.py new file mode 100644 index 0000000000..8dc641719b --- /dev/null +++ b/func-app/graphrag/index/utils/rate_limiter.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Rate limiter utility.""" + +import asyncio +import time + + +class RateLimiter: + """ + The original TpmRpmLLMLimiter strategy did not account for minute-based rate limiting when scheduled. + + The RateLimiter was introduced to ensure that the CommunityReportsExtractor could be scheduled to adhere to rate configurations on a per-minute basis. + """ + + # TODO: RateLimiter scheduled: using asyncio for async_mode + + def __init__(self, rate: int, per: int): + self.rate = rate + self.per = per + self.allowance = rate + self.last_check = time.monotonic() + + async def acquire(self): + """Acquire a token from the rate limiter.""" + current = time.monotonic() + elapsed = current - self.last_check + self.last_check = current + self.allowance += elapsed * (self.rate / self.per) + + if self.allowance > self.rate: + self.allowance = self.rate + + if self.allowance < 1.0: + sleep_time = (1.0 - self.allowance) * (self.per / self.rate) + await asyncio.sleep(sleep_time) + self.allowance = 0.0 + else: + self.allowance -= 1.0 diff --git a/func-app/graphrag/index/utils/string.py b/func-app/graphrag/index/utils/string.py new file mode 100644 index 0000000000..7e1654bb4e --- /dev/null +++ b/func-app/graphrag/index/utils/string.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""String utilities.""" + +import html +import re +from typing import Any + + +def clean_str(input: Any) -> str: + """Clean an input string by removing HTML escapes, control characters, and other unwanted characters.""" + # If we get non-string input, just give it back + if not isinstance(input, str): + return input + + result = html.unescape(input.strip()) + # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python + return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) diff --git a/func-app/graphrag/index/utils/tokens.py b/func-app/graphrag/index/utils/tokens.py new file mode 100644 index 0000000000..4a189b9b22 --- /dev/null +++ b/func-app/graphrag/index/utils/tokens.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Utilities for working with tokens.""" + +import logging + +import tiktoken + +DEFAULT_ENCODING_NAME = "cl100k_base" +log = logging.getLogger(__name__) + + +def num_tokens_from_string( + string: str, model: str | None = None, encoding_name: str | None = None +) -> int: + """Return the number of tokens in a text string.""" + if model is not None: + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. 
Fall back to default encoding {DEFAULT_ENCODING_NAME}" + log.warning(msg) + encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME) + else: + encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME) + return len(encoding.encode(string)) + + +def string_from_tokens( + tokens: list[int], model: str | None = None, encoding_name: str | None = None +) -> str: + """Return a text string from a list of tokens.""" + if model is not None: + encoding = tiktoken.encoding_for_model(model) + elif encoding_name is not None: + encoding = tiktoken.get_encoding(encoding_name) + else: + msg = "Either model or encoding_name must be specified." + raise ValueError(msg) + return encoding.decode(tokens) diff --git a/func-app/graphrag/index/utils/topological_sort.py b/func-app/graphrag/index/utils/topological_sort.py new file mode 100644 index 0000000000..a19b464559 --- /dev/null +++ b/func-app/graphrag/index/utils/topological_sort.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Topological sort utility method.""" + +from graphlib import TopologicalSorter + + +def topological_sort(graph: dict[str, list[str]]) -> list[str]: + """Topological sort.""" + ts = TopologicalSorter(graph) + return list(ts.static_order()) diff --git a/func-app/graphrag/index/utils/uuid.py b/func-app/graphrag/index/utils/uuid.py new file mode 100644 index 0000000000..0671fb09da --- /dev/null +++ b/func-app/graphrag/index/utils/uuid.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""UUID utilities.""" + +import uuid +from random import Random, getrandbits + + +def gen_uuid(rd: Random | None = None): + """Generate a random UUID v4.""" + return uuid.UUID( + int=rd.getrandbits(128) if rd is not None else getrandbits(128), version=4 + ).hex diff --git a/func-app/graphrag/index/verbs/__init__.py b/func-app/graphrag/index/verbs/__init__.py new file mode 100644 index 0000000000..379c2a3749 --- /dev/null +++ b/func-app/graphrag/index/verbs/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing get_default_verbs method definition.""" + +from .covariates import extract_covariates +from .entities import entity_extract, summarize_descriptions +from .genid import genid +from .graph import ( + cluster_graph, + create_community_reports, + create_graph, + embed_graph, + layout_graph, + merge_graphs, + unpack_graph, +) +from .overrides import aggregate, concat, merge +from .snapshot import snapshot +from .snapshot_rows import snapshot_rows +from .spread_json import spread_json +from .text import chunk, text_embed, text_split, text_translate +from .unzip import unzip +from .zip import zip_verb + +__all__ = [ + "aggregate", + "chunk", + "cluster_graph", + "concat", + "create_community_reports", + "create_graph", + "embed_graph", + "entity_extract", + "extract_covariates", + "genid", + "layout_graph", + "merge", + "merge_graphs", + "snapshot", + "snapshot_rows", + "spread_json", + "summarize_descriptions", + "text_embed", + "text_split", + "text_translate", + "unpack_graph", + "unzip", + "zip_verb", +] diff --git a/func-app/graphrag/index/verbs/covariates/__init__.py b/func-app/graphrag/index/verbs/covariates/__init__.py new file mode 100644 index 0000000000..cdebee228b --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. 
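+# The RateLimiter above is a token bucket: the allowance refills at rate/per
+# tokens per second, each acquire() spends one token, and callers sleep once
+# the bucket is empty. A minimal usage sketch with hypothetical limits:
+#
+#     limiter = RateLimiter(rate=10, per=60)  # ~10 acquisitions per minute
+#
+#     async def guarded(fn):
+#         await limiter.acquire()
+#         return await fn()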
+# Licensed under the MIT License + +"""The Indexing Engine covariates package root.""" + +from .extract_covariates import extract_covariates + +__all__ = ["extract_covariates"] diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/__init__.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/__init__.py new file mode 100644 index 0000000000..53d357bb46 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine text extract claims package root.""" + +from .extract_covariates import ExtractClaimsStrategyType, extract_covariates + +__all__ = ["ExtractClaimsStrategyType", "extract_covariates"] diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/extract_covariates.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/extract_covariates.py new file mode 100644 index 0000000000..a67cb0fa0e --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/extract_covariates.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing the extract_covariates verb definition.""" + +import logging +from dataclasses import asdict +from enum import Enum +from typing import Any, cast + +import pandas as pd +from datashaper import ( + AsyncType, + TableContainer, + VerbCallbacks, + VerbInput, + derive_from_rows, + verb, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.verbs.covariates.typing import Covariate, CovariateExtractStrategy + +log = logging.getLogger(__name__) + + +class ExtractClaimsStrategyType(str, Enum): + """ExtractClaimsStrategyType class definition.""" + + graph_intelligence = "graph_intelligence" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"] + + +@verb(name="extract_covariates") +async def extract_covariates( + input: VerbInput, + cache: PipelineCache, + callbacks: VerbCallbacks, + column: str, + covariate_type: str, + strategy: dict[str, Any] | None, + async_mode: AsyncType = AsyncType.AsyncIO, + entity_types: list[str] | None = None, + **kwargs, +) -> TableContainer: + """ + Extract claims from a piece of text. 
+ + ## Usage + TODO + """ + log.debug("extract_covariates strategy=%s", strategy) + if entity_types is None: + entity_types = DEFAULT_ENTITY_TYPES + output = cast(pd.DataFrame, input.get_input()) + + resolved_entities_map = {} + + strategy = strategy or {} + strategy_exec = load_strategy( + strategy.get("type", ExtractClaimsStrategyType.graph_intelligence) + ) + strategy_config = {**strategy} + + async def run_strategy(row): + text = row[column] + result = await strategy_exec( + text, entity_types, resolved_entities_map, callbacks, cache, strategy_config + ) + return [ + create_row_from_claim_data(row, item, covariate_type) + for item in result.covariate_data + ] + + results = await derive_from_rows( + output, + run_strategy, + callbacks, + scheduling_type=async_mode, + num_threads=kwargs.get("num_threads", 4), + ) + output = pd.DataFrame([item for row in results for item in row or []]) + return TableContainer(table=output) + + +def load_strategy(strategy_type: ExtractClaimsStrategyType) -> CovariateExtractStrategy: + """Load strategy method definition.""" + match strategy_type: + case ExtractClaimsStrategyType.graph_intelligence: + from .strategies.graph_intelligence import run as run_gi + + return run_gi + case _: + msg = f"Unknown strategy: {strategy_type}" + raise ValueError(msg) + + +def create_row_from_claim_data(row, covariate_data: Covariate, covariate_type: str): + """Create a row from the claim data and the input row.""" + item = {**row, **asdict(covariate_data), "covariate_type": covariate_type} + # TODO: doc_id from extraction isn't necessary + # since chunking happens before this + del item["doc_id"] + return item diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/__init__.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/__init__.py new file mode 100644 index 0000000000..605c66f8d1 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine text extract claims strategies package root.""" diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/__init__.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/__init__.py new file mode 100644 index 0000000000..ab01f06fc4 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine text extract claims strategies graph intelligence package root.""" + +from .run_gi_extract_claims import run + +__all__ = ["run"] diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/defaults.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/defaults.py new file mode 100644 index 0000000000..846bfa81e0 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/defaults.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Microsoft Corporation. 
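+# A minimal sketch of a strategy dict accepted by the extract_covariates verb
+# above; the keys mirror those read in run/_execute below, the
+# claim_description value is a placeholder, and the "static_response" llm
+# type is assumed to map to LLMType.StaticResponse:
+#
+#     strategy = {
+#         "type": "graph_intelligence",
+#         "claim_description": "Any claims about anti-competitive behavior.",
+#         "llm": {"type": "static_response", "responses": MOCK_LLM_RESPONSES},
+#     }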
+# Licensed under the MIT License + +"""A file containing MOCK_LLM_RESPONSES definition.""" + +MOCK_LLM_RESPONSES = [ + """ +[ + { + "subject": "COMPANY A", + "object": "GOVERNMENT AGENCY B", + "type": "ANTI-COMPETITIVE PRACTICES", + "status": "TRUE", + "start_date": "2022-01-10T00:00:00", + "end_date": "2022-01-10T00:00:00", + "description": "Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10", + "source_text": ["According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B."] + } +] + """.strip() +] diff --git a/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/run_gi_extract_claims.py b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/run_gi_extract_claims.py new file mode 100644 index 0000000000..1c9f058830 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/extract_covariates/strategies/graph_intelligence/run_gi_extract_claims.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run and _run_chain methods definitions.""" + +from collections.abc import Iterable +from typing import Any + +from datashaper import VerbCallbacks + +import graphrag.config.defaults as defs +from graphrag.config.enums import LLMType +from graphrag.index.cache import PipelineCache +from graphrag.index.graph.extractors.claims import ClaimExtractor +from graphrag.index.llm import load_llm +from graphrag.index.verbs.covariates.typing import ( + Covariate, + CovariateExtractionResult, +) +from graphrag.llm import CompletionLLM + +from .defaults import MOCK_LLM_RESPONSES + + +async def run( + input: str | Iterable[str], + entity_types: list[str], + resolved_entities_map: dict[str, str], + reporter: VerbCallbacks, + pipeline_cache: PipelineCache, + strategy_config: dict[str, Any], +) -> CovariateExtractionResult: + """Run the Claim extraction chain.""" + llm_config = strategy_config.get( + "llm", {"type": LLMType.StaticResponse, "responses": MOCK_LLM_RESPONSES} + ) + llm_type = llm_config.get("type", LLMType.StaticResponse) + llm = load_llm("claim_extraction", llm_type, reporter, pipeline_cache, llm_config) + return await _execute( + llm, input, entity_types, resolved_entities_map, reporter, strategy_config + ) + + +async def _execute( + llm: CompletionLLM, + texts: Iterable[str], + entity_types: list[str], + resolved_entities_map: dict[str, str], + reporter: VerbCallbacks, + strategy_config: dict[str, Any], +) -> CovariateExtractionResult: + extraction_prompt = strategy_config.get("extraction_prompt") + max_gleanings = strategy_config.get("max_gleanings", defs.CLAIM_MAX_GLEANINGS) + tuple_delimiter = strategy_config.get("tuple_delimiter") + record_delimiter = strategy_config.get("record_delimiter") + completion_delimiter = strategy_config.get("completion_delimiter") + encoding_model = strategy_config.get("encoding_name") + + extractor = ClaimExtractor( + llm_invoker=llm, + extraction_prompt=extraction_prompt, + max_gleanings=max_gleanings, + encoding_model=encoding_model, + on_error=lambda e, s, d: ( + reporter.error("Claim Extraction Error", e, s, d) if reporter else None + ), + ) + + claim_description = strategy_config.get("claim_description") + if claim_description is None: + msg = "claim_description 
is required for claim extraction" + raise ValueError(msg) + + texts = [texts] if isinstance(texts, str) else texts + + results = await extractor({ + "input_text": texts, + "entity_specs": entity_types, + "resolved_entities": resolved_entities_map, + "claim_description": claim_description, + "tuple_delimiter": tuple_delimiter, + "record_delimiter": record_delimiter, + "completion_delimiter": completion_delimiter, + }) + + claim_data = results.output + return CovariateExtractionResult([create_covariate(item) for item in claim_data]) + + +def create_covariate(item: dict[str, Any]) -> Covariate: + """Create a covariate from the item.""" + return Covariate( + subject_id=item.get("subject_id"), + subject_type=item.get("subject_type"), + object_id=item.get("object_id"), + object_type=item.get("object_type"), + type=item.get("type"), + status=item.get("status"), + start_date=item.get("start_date"), + end_date=item.get("end_date"), + description=item.get("description"), + source_text=item.get("source_text"), + doc_id=item.get("doc_id"), + record_id=item.get("record_id"), + id=item.get("id"), + ) diff --git a/func-app/graphrag/index/verbs/covariates/typing.py b/func-app/graphrag/index/verbs/covariates/typing.py new file mode 100644 index 0000000000..e31cfa4989 --- /dev/null +++ b/func-app/graphrag/index/verbs/covariates/typing.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'Covariate' and 'CovariateExtractionResult' models.""" + +from collections.abc import Awaitable, Callable, Iterable +from dataclasses import dataclass +from typing import Any + +from datashaper import VerbCallbacks + +from graphrag.index.cache import PipelineCache + + +@dataclass +class Covariate: + """Covariate class definition.""" + + covariate_type: str | None = None + subject_id: str | None = None + subject_type: str | None = None + object_id: str | None = None + object_type: str | None = None + type: str | None = None + status: str | None = None + start_date: str | None = None + end_date: str | None = None + description: str | None = None + source_text: list[str] | None = None + doc_id: str | None = None + record_id: int | None = None + id: str | None = None + + +@dataclass +class CovariateExtractionResult: + """Covariate extraction result class definition.""" + + covariate_data: list[Covariate] + + +CovariateExtractStrategy = Callable[ + [ + Iterable[str], + list[str], + dict[str, str], + VerbCallbacks, + PipelineCache, + dict[str, Any], + ], + Awaitable[CovariateExtractionResult], +] diff --git a/func-app/graphrag/index/verbs/entities/__init__.py b/func-app/graphrag/index/verbs/entities/__init__.py new file mode 100644 index 0000000000..2f55d710e9 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine entities package root.""" + +from .extraction import entity_extract +from .summarize import summarize_descriptions + +__all__ = ["entity_extract", "summarize_descriptions"] diff --git a/func-app/graphrag/index/verbs/entities/extraction/__init__.py b/func-app/graphrag/index/verbs/entities/extraction/__init__.py new file mode 100644 index 0000000000..46e6d54581 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. 
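+# Any coroutine matching the CovariateExtractStrategy alias above can be
+# plugged in as a strategy; a hypothetical no-op implementation:
+#
+#     async def run_noop(texts, entity_types, resolved_map, callbacks, cache, config):
+#         return CovariateExtractionResult(covariate_data=[])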
+# Licensed under the MIT License + +"""The Indexing Engine entities extraction package root.""" + +from .entity_extract import ExtractEntityStrategyType, entity_extract + +__all__ = ["ExtractEntityStrategyType", "entity_extract"] diff --git a/func-app/graphrag/index/verbs/entities/extraction/entity_extract.py b/func-app/graphrag/index/verbs/entities/extraction/entity_extract.py new file mode 100644 index 0000000000..4e961f674d --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/entity_extract.py @@ -0,0 +1,202 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing entity_extract methods.""" + +import logging +from enum import Enum +from typing import Any, cast + +import pandas as pd +from datashaper import ( + AsyncType, + TableContainer, + VerbCallbacks, + VerbInput, + derive_from_rows, + verb, +) + +from graphrag.index.bootstrap import bootstrap +from graphrag.index.cache import PipelineCache + +from .strategies.typing import Document, EntityExtractStrategy + +log = logging.getLogger(__name__) + + +class ExtractEntityStrategyType(str, Enum): + """ExtractEntityStrategyType class definition.""" + + graph_intelligence = "graph_intelligence" + graph_intelligence_json = "graph_intelligence_json" + nltk = "nltk" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"] + + +@verb(name="entity_extract") +async def entity_extract( + input: VerbInput, + cache: PipelineCache, + callbacks: VerbCallbacks, + column: str, + id_column: str, + to: str, + strategy: dict[str, Any] | None, + graph_to: str | None = None, + async_mode: AsyncType = AsyncType.AsyncIO, + entity_types=DEFAULT_ENTITY_TYPES, + **kwargs, +) -> TableContainer: + """ + Extract entities from a piece of text. + + ## Usage + ### json + ```json + { + "verb": "entity_extract", + "args": { + "column": "the_document_text_column_to_extract_entities_from", /* In general this will be your document text column */ + "id_column": "the_column_with_the_unique_id_for_each_row", /* In general this will be your document id */ + "to": "the_column_to_output_the_entities_to", /* This will be a list[dict[str, Any]] a list of entities, with a name, and additional attributes */ + "graph_to": "the_column_to_output_the_graphml_to", /* Optional: This will be a graphml graph in string form which represents the entities and their relationships */ + "strategy": {...} , see strategies section below + "entity_types": ["list", "of", "entity", "types", "to", "extract"] /* Optional: This will limit the entity types extracted, default: ["organization", "person", "geo", "event"] */ + "summarize_descriptions" : true | false /* Optional: This will summarize the descriptions of the entities and relationships, default: true */ + } + } + ``` + ### yaml + ```yaml + verb: entity_extract + args: + column: the_document_text_column_to_extract_entities_from + id_column: the_column_with_the_unique_id_for_each_row + to: the_column_to_output_the_entities_to + graph_to: the_column_to_output_the_graphml_to + strategy: , see strategies section below + summarize_descriptions: true | false /* Optional: This will summarize the descriptions of the entities and relationships, default: true */ + entity_types: + - list + - of + - entity + - types + - to + - extract + ``` + + ## Strategies + The entity extract verb uses a strategy to extract entities from a document. 
The strategy is a json object which defines the strategy to use. The following strategies are available:
+
+    ### graph_intelligence
+    This strategy uses the [graph_intelligence] library to extract entities from a document. In particular, it uses an LLM to extract entities from a piece of text. The strategy config is as follows:
+
+    ```yml
+    strategy:
+        type: graph_intelligence
+        extraction_prompt: !include ./entity_extraction_prompt.txt # Optional, the prompt to use for extraction
+        completion_delimiter: "<|COMPLETE|>" # Optional, the delimiter to use for the LLM to mark completion
+        tuple_delimiter: "<|>" # Optional, the delimiter to use for the LLM to mark a tuple
+        record_delimiter: "##" # Optional, the delimiter to use for the LLM to mark a record
+
+        prechunked: true | false # Optional, whether the document is already chunked beforehand; otherwise this will chunk the document into smaller bits. default: false
+        encoding_name: cl100k_base # Optional, the encoding to use for the LLM, if not already prechunked, default: cl100k_base
+        chunk_size: 1000 # Optional, the chunk size to use for the LLM, if not already prechunked, default: 1200
+        chunk_overlap: 100 # Optional, the chunk overlap to use for the LLM, if not already prechunked, default: 100
+
+        llm: # The configuration for the LLM
+            type: openai # the type of llm to use, available options are: openai, azure, openai_chat, azure_openai_chat. The last two being chat based LLMs.
+            api_key: !ENV ${GRAPHRAG_OPENAI_API_KEY} # The api key to use for openai
+            model: !ENV ${GRAPHRAG_OPENAI_MODEL:gpt-4-turbo-preview} # The model to use for openai
+            max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai
+            organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai
+
+            # if using azure flavor
+            api_base: !ENV ${GRAPHRAG_OPENAI_API_BASE} # The api base to use for azure
+            api_version: !ENV ${GRAPHRAG_OPENAI_API_VERSION} # The api version to use for azure
+            proxy: !ENV ${GRAPHRAG_OPENAI_PROXY} # The proxy to use for azure
+
+    ```
+
+    ### nltk
+    This strategy uses the [nltk] library to extract entities from a document. In particular, it applies nltk part-of-speech tagging and named-entity chunking to a piece of text. 
The strategy config is as follows: + ```yml + strategy: + type: nltk + ``` + """ + log.debug("entity_extract strategy=%s", strategy) + if entity_types is None: + entity_types = DEFAULT_ENTITY_TYPES + output = cast(pd.DataFrame, input.get_input()) + strategy = strategy or {} + strategy_exec = _load_strategy( + strategy.get("type", ExtractEntityStrategyType.graph_intelligence) + ) + strategy_config = {**strategy} + + num_started = 0 + + async def run_strategy(row): + nonlocal num_started + text = row[column] + id = row[id_column] + result = await strategy_exec( + [Document(text=text, id=id)], + entity_types, + callbacks, + cache, + strategy_config, + ) + num_started += 1 + return [result.entities, result.graphml_graph] + + results = await derive_from_rows( + output, + run_strategy, + callbacks, + scheduling_type=async_mode, + num_threads=kwargs.get("num_threads", 4), + ) + + to_result = [] + graph_to_result = [] + for result in results: + if result: + to_result.append(result[0]) + graph_to_result.append(result[1]) + else: + to_result.append(None) + graph_to_result.append(None) + + output[to] = to_result + if graph_to is not None: + output[graph_to] = graph_to_result + + return TableContainer(table=output.reset_index(drop=True)) + + +def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy: + """Load strategy method definition.""" + match strategy_type: + case ExtractEntityStrategyType.graph_intelligence: + from .strategies.graph_intelligence import run_gi + + return run_gi + + case ExtractEntityStrategyType.nltk: + bootstrap() + # dynamically import nltk strategy to avoid dependency if not used + from .strategies.nltk import run as run_nltk + + return run_nltk + case _: + msg = f"Unknown strategy: {strategy_type}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/__init__.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/__init__.py new file mode 100644 index 0000000000..f5cc17d750 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine entities extraction strategies package root.""" diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/__init__.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/__init__.py new file mode 100644 index 0000000000..083c0e4112 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph intelligence package root.""" + +from .run_graph_intelligence import run_gi + +__all__ = ["run_gi"] diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/defaults.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/defaults.py new file mode 100644 index 0000000000..237e6657c8 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/defaults.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. 
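+# The mock responses below use the extractor's delimited record format:
+# ("entity"<|>NAME<|>TYPE<|>DESCRIPTION) and
+# ("relationship"<|>SOURCE<|>TARGET<|>DESCRIPTION<|>WEIGHT), with "##"
+# separating records; these match the record/tuple delimiters documented in
+# the entity_extract strategy config above. A rough sketch of splitting one
+# response into records and fields:
+#
+#     records = [r.strip() for r in response.split("##") if r.strip()]
+#     fields = records[0].strip("()").split("<|>")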
+# Licensed under the MIT License + +"""A file containing some default responses.""" + +from graphrag.config.enums import LLMType + +MOCK_LLM_RESPONSES = [ + """ + ("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company) + ## + ("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A) + ## + ("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A) + ## + ("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2) + ## + ("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1)) + """.strip() +] + +DEFAULT_LLM_CONFIG = { + "type": LLMType.StaticResponse, + "responses": MOCK_LLM_RESPONSES, +} diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/run_graph_intelligence.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/run_graph_intelligence.py new file mode 100644 index 0000000000..0628487983 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/graph_intelligence/run_graph_intelligence.py @@ -0,0 +1,142 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run_gi, run_extract_entities and _create_text_splitter methods to run graph intelligence.""" + +import networkx as nx +from datashaper import VerbCallbacks + +import graphrag.config.defaults as defs +from graphrag.config.enums import LLMType +from graphrag.index.cache import PipelineCache +from graphrag.index.graph.extractors.graph import GraphExtractor +from graphrag.index.llm import load_llm +from graphrag.index.text_splitting import ( + NoopTextSplitter, + TextSplitter, + TokenTextSplitter, +) +from graphrag.index.verbs.entities.extraction.strategies.typing import ( + Document, + EntityExtractionResult, + EntityTypes, + StrategyConfig, +) +from graphrag.llm import CompletionLLM + +from .defaults import DEFAULT_LLM_CONFIG + + +async def run_gi( + docs: list[Document], + entity_types: EntityTypes, + reporter: VerbCallbacks, + pipeline_cache: PipelineCache, + args: StrategyConfig, +) -> EntityExtractionResult: + """Run the graph intelligence entity extraction strategy.""" + llm_config = args.get("llm", DEFAULT_LLM_CONFIG) + llm_type = llm_config.get("type", LLMType.StaticResponse) + llm = load_llm("entity_extraction", llm_type, reporter, pipeline_cache, llm_config) + return await run_extract_entities(llm, docs, entity_types, reporter, args) + + +async def run_extract_entities( + llm: CompletionLLM, + docs: list[Document], + entity_types: EntityTypes, + reporter: VerbCallbacks | None, + args: StrategyConfig, +) -> EntityExtractionResult: + """Run the entity extraction chain.""" + encoding_name = args.get("encoding_name", "cl100k_base") + + # Chunking Arguments + prechunked = args.get("prechunked", False) + chunk_size = args.get("chunk_size", defs.CHUNK_SIZE) + chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP) + + # Extraction Arguments + tuple_delimiter = args.get("tuple_delimiter", None) + record_delimiter = args.get("record_delimiter", None) + completion_delimiter = args.get("completion_delimiter", None) + extraction_prompt = args.get("extraction_prompt", None) + encoding_model = args.get("encoding_name", None) + max_gleanings = args.get("max_gleanings", defs.ENTITY_EXTRACTION_MAX_GLEANINGS) + + # note: We're not using 
UnipartiteGraphChain.from_params + # because we want to pass "timeout" to the llm_kwargs + text_splitter = _create_text_splitter( + prechunked, chunk_size, chunk_overlap, encoding_name + ) + + extractor = GraphExtractor( + llm_invoker=llm, + prompt=extraction_prompt, + encoding_model=encoding_model, + max_gleanings=max_gleanings, + on_error=lambda e, s, d: ( + reporter.error("Entity Extraction Error", e, s, d) if reporter else None + ), + ) + text_list = [doc.text.strip() for doc in docs] + + # If it's not pre-chunked, then re-chunk the input + if not prechunked: + text_list = text_splitter.split_text("\n".join(text_list)) + + results = await extractor( + list(text_list), + { + "entity_types": entity_types, + "tuple_delimiter": tuple_delimiter, + "record_delimiter": record_delimiter, + "completion_delimiter": completion_delimiter, + }, + ) + + graph = results.output + # Map the "source_id" back to the "id" field + for _, node in graph.nodes(data=True): # type: ignore + if node is not None: + node["source_id"] = ",".join( + docs[int(id)].id for id in node["source_id"].split(",") + ) + + for _, _, edge in graph.edges(data=True): # type: ignore + if edge is not None: + edge["source_id"] = ",".join( + docs[int(id)].id for id in edge["source_id"].split(",") + ) + + entities = [ + ({"name": item[0], **(item[1] or {})}) + for item in graph.nodes(data=True) + if item is not None + ] + + graph_data = "".join(nx.generate_graphml(graph)) + return EntityExtractionResult(entities, graph_data) + + +def _create_text_splitter( + prechunked: bool, chunk_size: int, chunk_overlap: int, encoding_name: str +) -> TextSplitter: + """Create a text splitter for the extraction chain. + + Args: + - prechunked - Whether the text is already chunked + - chunk_size - The size of each chunk + - chunk_overlap - The overlap between chunks + - encoding_name - The name of the encoding to use + Returns: + - output - A text splitter + """ + if prechunked: + return NoopTextSplitter() + + return TokenTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + encoding_name=encoding_name, + ) diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/nltk.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/nltk.py new file mode 100644 index 0000000000..48d4dae4ca --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/nltk.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Microsoft Corporation. 
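+# A minimal sketch of the source_id remapping in run_extract_entities above:
+# the extractor records chunk indices on nodes and edges, and these are
+# mapped back to the originating document ids (hypothetical ids):
+#
+#     docs = [Document(text="...", id="doc-a"), Document(text="...", id="doc-b")]
+#     # a node extracted from both chunks carries source_id "0,1",
+#     # which becomes "doc-a,doc-b"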
+# Licensed under the MIT License + +"""A module containing run method definition.""" + +import networkx as nx +import nltk +from datashaper import VerbCallbacks +from nltk.corpus import words + +from graphrag.index.cache import PipelineCache + +from .typing import Document, EntityExtractionResult, EntityTypes, StrategyConfig + +# Need to do this cause we're potentially multithreading, and nltk doesn't like that +words.ensure_loaded() + + +async def run( # noqa RUF029 async is required for interface + docs: list[Document], + entity_types: EntityTypes, + reporter: VerbCallbacks, # noqa ARG001 + pipeline_cache: PipelineCache, # noqa ARG001 + args: StrategyConfig, # noqa ARG001 +) -> EntityExtractionResult: + """Run method definition.""" + entity_map = {} + graph = nx.Graph() + for doc in docs: + connected_entities = [] + for chunk in nltk.ne_chunk(nltk.pos_tag(nltk.word_tokenize(doc.text))): + if hasattr(chunk, "label"): + entity_type = chunk.label().lower() + if entity_type in entity_types: + name = (" ".join(c[0] for c in chunk)).upper() + connected_entities.append(name) + if name not in entity_map: + entity_map[name] = entity_type + graph.add_node( + name, type=entity_type, description=name, source_id=doc.id + ) + + # connect the entities if they appear in the same document + if len(connected_entities) > 1: + for i in range(len(connected_entities)): + for j in range(i + 1, len(connected_entities)): + description = f"{connected_entities[i]} -> {connected_entities[j]}" + graph.add_edge( + connected_entities[i], + connected_entities[j], + description=description, + source_id=doc.id, + ) + + return EntityExtractionResult( + entities=[ + {"type": entity_type, "name": name} + for name, entity_type in entity_map.items() + ], + graphml_graph="".join(nx.generate_graphml(graph)), + ) diff --git a/func-app/graphrag/index/verbs/entities/extraction/strategies/typing.py b/func-app/graphrag/index/verbs/entities/extraction/strategies/typing.py new file mode 100644 index 0000000000..45d3f1b80e --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/extraction/strategies/typing.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'Document' and 'EntityExtractionResult' models.""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any + +from datashaper import VerbCallbacks + +from graphrag.index.cache import PipelineCache + +ExtractedEntity = dict[str, Any] +StrategyConfig = dict[str, Any] +EntityTypes = list[str] + + +@dataclass +class Document: + """Document class definition.""" + + text: str + id: str + + +@dataclass +class EntityExtractionResult: + """Entity extraction result class definition.""" + + entities: list[ExtractedEntity] + graphml_graph: str | None + + +EntityExtractStrategy = Callable[ + [ + list[Document], + EntityTypes, + VerbCallbacks, + PipelineCache, + StrategyConfig, + ], + Awaitable[EntityExtractionResult], +] diff --git a/func-app/graphrag/index/verbs/entities/summarize/__init__.py b/func-app/graphrag/index/verbs/entities/summarize/__init__.py new file mode 100644 index 0000000000..d7e9a5d93a --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/summarize/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. 
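+# The nltk strategy above assumes the standard NLTK models are installed
+# (graphrag's bootstrap() handles this); a rough manual equivalent, using the
+# usual NLTK package names:
+#
+#     import nltk
+#     for pkg in ("punkt", "averaged_perceptron_tagger",
+#                 "maxent_ne_chunker", "words"):
+#         nltk.download(pkg)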
+# Licensed under the MIT License + +"""Root package for entity summarization.""" + +from .description_summarize import SummarizeStrategyType, summarize_descriptions + +__all__ = ["SummarizeStrategyType", "summarize_descriptions"] diff --git a/func-app/graphrag/index/verbs/entities/summarize/description_summarize.py b/func-app/graphrag/index/verbs/entities/summarize/description_summarize.py new file mode 100644 index 0000000000..5b7feb4184 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/summarize/description_summarize.py @@ -0,0 +1,207 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing the summarize_descriptions verb.""" + +import asyncio +import logging +from enum import Enum +from typing import Any, NamedTuple, cast + +import networkx as nx +import pandas as pd +from datashaper import ( + ProgressTicker, + TableContainer, + VerbCallbacks, + VerbInput, + progress_ticker, + verb, +) + +from graphrag.index.cache import PipelineCache +from graphrag.index.utils import load_graph + +from .strategies.typing import SummarizationStrategy + +log = logging.getLogger(__name__) + + +class DescriptionSummarizeRow(NamedTuple): + """DescriptionSummarizeRow class definition.""" + + graph: Any + + +class SummarizeStrategyType(str, Enum): + """SummarizeStrategyType class definition.""" + + graph_intelligence = "graph_intelligence" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +@verb(name="summarize_descriptions") +async def summarize_descriptions( + input: VerbInput, + cache: PipelineCache, + callbacks: VerbCallbacks, + column: str, + to: str, + strategy: dict[str, Any] | None = None, + **kwargs, +) -> TableContainer: + """ + Summarize entity and relationship descriptions from an entity graph. + + ## Usage + + To turn this feature ON please set the environment variable `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_ENABLED=True`. + + ### json + + ```json + { + "verb": "", + "args": { + "column": "the_document_text_column_to_extract_descriptions_from", /* Required: This will be a graphml graph in string form which represents the entities and their relationships */ + "to": "the_column_to_output_the_summarized_descriptions_to", /* Required: This will be a graphml graph in string form which represents the entities and their relationships after being summarized */ + "strategy": {...} , see strategies section below + } + } + ``` + + ### yaml + + ```yaml + verb: entity_extract + args: + column: the_document_text_column_to_extract_descriptions_from + to: the_column_to_output_the_summarized_descriptions_to + strategy: , see strategies section below + ``` + + ## Strategies + + The summarize descriptions verb uses a strategy to summarize descriptions for entities. The strategy is a json object which defines the strategy to use. The following strategies are available: + + ### graph_intelligence + + This strategy uses the [graph_intelligence] library to summarize descriptions for entities. The strategy config is as follows: + + ```yml + strategy: + type: graph_intelligence + summarize_prompt: # Optional, the prompt to use for extraction + + + llm: # The configuration for the LLM + type: openai # the type of llm to use, available options are: openai, azure, openai_chat, azure_openai_chat. The last two being chat based LLMs. 
+ api_key: !ENV ${GRAPHRAG_OPENAI_API_KEY} # The api key to use for openai + model: !ENV ${GRAPHRAG_OPENAI_MODEL:gpt-4-turbo-preview} # The model to use for openai + max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai + organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai + + # if using azure flavor + api_base: !ENV ${GRAPHRAG_OPENAI_API_BASE} # The api base to use for azure + api_version: !ENV ${GRAPHRAG_OPENAI_API_VERSION} # The api version to use for azure + proxy: !ENV ${GRAPHRAG_OPENAI_PROXY} # The proxy to use for azure + ``` + """ + log.debug("summarize_descriptions strategy=%s", strategy) + output = cast(pd.DataFrame, input.get_input()) + strategy = strategy or {} + strategy_exec = load_strategy( + strategy.get("type", SummarizeStrategyType.graph_intelligence) + ) + strategy_config = {**strategy} + + async def get_resolved_entities(row, semaphore: asyncio.Semaphore): + graph: nx.Graph = load_graph(cast(str | nx.Graph, getattr(row, column))) + + ticker_length = len(graph.nodes) + len(graph.edges) + + ticker = progress_ticker(callbacks.progress, ticker_length) + + futures = [ + do_summarize_descriptions( + node, + sorted(set(graph.nodes[node].get("description", "").split("\n"))), + ticker, + semaphore, + ) + for node in graph.nodes() + ] + futures += [ + do_summarize_descriptions( + edge, + sorted(set(graph.edges[edge].get("description", "").split("\n"))), + ticker, + semaphore, + ) + for edge in graph.edges() + ] + + results = await asyncio.gather(*futures) + + for result in results: + graph_item = result.items + if isinstance(graph_item, str) and graph_item in graph.nodes(): + graph.nodes[graph_item]["description"] = result.description + elif isinstance(graph_item, tuple) and graph_item in graph.edges(): + graph.edges[graph_item]["description"] = result.description + + return DescriptionSummarizeRow( + graph="\n".join(nx.generate_graphml(graph)), + ) + + async def do_summarize_descriptions( + graph_item: str | tuple[str, str], + descriptions: list[str], + ticker: ProgressTicker, + semaphore: asyncio.Semaphore, + ): + async with semaphore: + results = await strategy_exec( + graph_item, + descriptions, + callbacks, + cache, + strategy_config, + ) + ticker(1) + return results + + # Graph is always on row 0, so here a derive from rows does not work + # This iteration will only happen once, but avoids hardcoding a iloc[0] + # Since parallelization is at graph level (nodes and edges), we can't use + # the parallelization of the derive_from_rows + semaphore = asyncio.Semaphore(kwargs.get("num_threads", 4)) + + results = [ + await get_resolved_entities(row, semaphore) for row in output.itertuples() + ] + + to_result = [] + + for result in results: + if result: + to_result.append(result.graph) + else: + to_result.append(None) + output[to] = to_result + return TableContainer(table=output) + + +def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy: + """Load strategy method definition.""" + match strategy_type: + case SummarizeStrategyType.graph_intelligence: + from .strategies.graph_intelligence import run as run_gi + + return run_gi + case _: + msg = f"Unknown strategy: {strategy_type}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/entities/summarize/strategies/__init__.py b/func-app/graphrag/index/verbs/entities/summarize/strategies/__init__.py new file mode 100644 index 0000000000..28c398e6ac --- /dev/null +++ 
b/func-app/graphrag/index/verbs/entities/summarize/strategies/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Indexing Engine - Summarization Strategies Package.""" + +from .typing import SummarizationStrategy + +__all__ = ["SummarizationStrategy"] diff --git a/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/__init__.py b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/__init__.py new file mode 100644 index 0000000000..a98d9406cb --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Entity summarization graph intelligence package root.""" + +from .run_graph_intelligence import run + +__all__ = ["run"] diff --git a/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/defaults.py b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/defaults.py new file mode 100644 index 0000000000..8ac42aa13d --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/defaults.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A file containing some default responses.""" + +from graphrag.config.enums import LLMType + +MOCK_LLM_RESPONSES = [ + """ + This is a MOCK response for the LLM. It is summarized! + """.strip() +] + +DEFAULT_LLM_CONFIG = { + "type": LLMType.StaticResponse, + "responses": MOCK_LLM_RESPONSES, +} diff --git a/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/run_graph_intelligence.py b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/run_graph_intelligence.py new file mode 100644 index 0000000000..57a1ecd218 --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/summarize/strategies/graph_intelligence/run_graph_intelligence.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024 Microsoft Corporation. 
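+# When no llm config is supplied, the run strategy below falls back to the
+# static DEFAULT_LLM_CONFIG above, making runs deterministic and offline:
+# every multi-description input is "summarized" to the fixed mock string.
+# A rough usage sketch (argument values are placeholders):
+#
+#     result = await run(("A", "B"), ["desc one", "desc two"],
+#                        callbacks, cache, {})
+#     # result.description == the mock response above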
+# Licensed under the MIT License + +"""A module containing run_gi, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence.""" + +from datashaper import VerbCallbacks + +from graphrag.config.enums import LLMType +from graphrag.index.cache import PipelineCache +from graphrag.index.graph.extractors.summarize import SummarizeExtractor +from graphrag.index.llm import load_llm +from graphrag.index.verbs.entities.summarize.strategies.typing import ( + StrategyConfig, + SummarizedDescriptionResult, +) +from graphrag.llm import CompletionLLM + +from .defaults import DEFAULT_LLM_CONFIG + + +async def run( + described_items: str | tuple[str, str], + descriptions: list[str], + reporter: VerbCallbacks, + pipeline_cache: PipelineCache, + args: StrategyConfig, +) -> SummarizedDescriptionResult: + """Run the graph intelligence entity extraction strategy.""" + llm_config = args.get("llm", DEFAULT_LLM_CONFIG) + llm_type = llm_config.get("type", LLMType.StaticResponse) + llm = load_llm( + "summarize_descriptions", llm_type, reporter, pipeline_cache, llm_config + ) + return await run_summarize_descriptions( + llm, described_items, descriptions, reporter, args + ) + + +async def run_summarize_descriptions( + llm: CompletionLLM, + items: str | tuple[str, str], + descriptions: list[str], + reporter: VerbCallbacks, + args: StrategyConfig, +) -> SummarizedDescriptionResult: + """Run the entity extraction chain.""" + # Extraction Arguments + summarize_prompt = args.get("summarize_prompt", None) + entity_name_key = args.get("entity_name_key", "entity_name") + input_descriptions_key = args.get("input_descriptions_key", "description_list") + max_tokens = args.get("max_tokens", None) + + extractor = SummarizeExtractor( + llm_invoker=llm, + summarization_prompt=summarize_prompt, + entity_name_key=entity_name_key, + input_descriptions_key=input_descriptions_key, + on_error=lambda e, stack, details: ( + reporter.error("Entity Extraction Error", e, stack, details) + if reporter + else None + ), + max_summary_length=args.get("max_summary_length", None), + max_input_tokens=max_tokens, + ) + + result = await extractor(items=items, descriptions=descriptions) + return SummarizedDescriptionResult( + items=result.items, description=result.description + ) diff --git a/func-app/graphrag/index/verbs/entities/summarize/strategies/typing.py b/func-app/graphrag/index/verbs/entities/summarize/strategies/typing.py new file mode 100644 index 0000000000..398295031b --- /dev/null +++ b/func-app/graphrag/index/verbs/entities/summarize/strategies/typing.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Microsoft Corporation. 
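+# Any coroutine matching the SummarizationStrategy alias below can stand in
+# for graph intelligence; a hypothetical LLM-free fallback that just joins
+# the descriptions:
+#
+#     async def run_join(items, descriptions, reporter, cache, args):
+#         return SummarizedDescriptionResult(
+#             items=items, description="\n".join(descriptions)
+#         )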
+# Licensed under the MIT License
+
+"""A module containing 'SummarizedDescriptionResult' model."""
+
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass
+from typing import Any
+
+from datashaper import VerbCallbacks
+
+from graphrag.index.cache import PipelineCache
+
+StrategyConfig = dict[str, Any]
+
+
+@dataclass
+class SummarizedDescriptionResult:
+    """Entity summarization result class definition."""
+
+    items: str | tuple[str, str]
+    description: str
+
+
+SummarizationStrategy = Callable[
+    [
+        str | tuple[str, str],
+        list[str],
+        VerbCallbacks,
+        PipelineCache,
+        StrategyConfig,
+    ],
+    Awaitable[SummarizedDescriptionResult],
+]
diff --git a/func-app/graphrag/index/verbs/genid.py b/func-app/graphrag/index/verbs/genid.py
new file mode 100644
index 0000000000..019ffc2da0
--- /dev/null
+++ b/func-app/graphrag/index/verbs/genid.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing genid method definition."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+from graphrag.index.utils import gen_md5_hash
+
+
+@verb(name="genid")
+def genid(
+    input: VerbInput,
+    to: str,
+    method: str = "md5_hash",
+    hash: list[str] = [],  # noqa A002
+    **_kwargs: dict,
+) -> TableContainer:
+    """
+    Generate a unique id for each row in the tabular data.
+
+    ## Usage
+    ### json
+    ```json
+    {
+        "verb": "genid",
+        "args": {
+            "to": "id_output_column_name", /* The name of the column to output the id to */
+            "method": "md5_hash", /* The method to use to generate the id, one of: md5_hash, increment */
+            "hash": ["list", "of", "column", "names"] /* only if using md5_hash */,
+            "seed": 034324 /* A random seed; accepted for compatibility but not used by either method */
+        }
+    }
+    ```
+
+    ### yaml
+    ```yaml
+    verb: genid
+    args:
+        to: id_output_column_name
+        method: md5_hash
+        hash:
+            - list
+            - of
+            - column
+            - names
+        seed: 034324
+    ```
+    """
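+    # Note: md5_hash derives a deterministic id from the configured "hash"
+    # columns, so re-running the pipeline yields stable ids; increment simply
+    # numbers rows from 1 and is therefore order-dependent.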
+    data = cast(pd.DataFrame, input.source.table)
+
+    if method == "md5_hash":
+        if len(hash) == 0:
+            msg = 'Must specify the "hash" columns to use md5_hash method'
+            raise ValueError(msg)
+
+        data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
+    elif method == "increment":
+        data[to] = data.index + 1
+    else:
+        msg = f"Unknown method {method}"
+        raise ValueError(msg)
+    return TableContainer(table=data)
diff --git a/func-app/graphrag/index/verbs/graph/__init__.py b/func-app/graphrag/index/verbs/graph/__init__.py
new file mode 100644
index 0000000000..5edbdbe530
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/__init__.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph package root."""
+
+from .clustering import cluster_graph
+from .compute_edge_combined_degree import compute_edge_combined_degree
+from .create import DEFAULT_EDGE_ATTRIBUTES, DEFAULT_NODE_ATTRIBUTES, create_graph
+from .embed import embed_graph
+from .layout import layout_graph
+from .merge import merge_graphs
+from .report import (
+    create_community_reports,
+    prepare_community_reports,
+    prepare_community_reports_claims,
+    prepare_community_reports_edges,
+    restore_community_hierarchy,
+)
+from .unpack import unpack_graph
+
+__all__ = [
+    "DEFAULT_EDGE_ATTRIBUTES",
+    "DEFAULT_NODE_ATTRIBUTES",
+    "cluster_graph",
+    "compute_edge_combined_degree",
+    "create_community_reports",
+    "create_graph",
+    "embed_graph",
+    "layout_graph",
+    "merge_graphs",
+    "prepare_community_reports",
+    "prepare_community_reports_claims",
+    "prepare_community_reports_edges",
+    "restore_community_hierarchy",
+    "unpack_graph",
+]
diff --git a/func-app/graphrag/index/verbs/graph/clustering/__init__.py b/func-app/graphrag/index/verbs/graph/clustering/__init__.py
new file mode 100644
index 0000000000..a5db89bb7f
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/clustering/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph clustering package root."""
+
+from .cluster_graph import GraphCommunityStrategyType, cluster_graph
+
+__all__ = ["GraphCommunityStrategyType", "cluster_graph"]
diff --git a/func-app/graphrag/index/verbs/graph/clustering/cluster_graph.py b/func-app/graphrag/index/verbs/graph/clustering/cluster_graph.py
new file mode 100644
index 0000000000..0cfb929c63
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/clustering/cluster_graph.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing cluster_graph, apply_clustering and run_layout methods definition."""
+
+import logging
+from enum import Enum
+from hashlib import sha256
+from typing import Any, cast
+
+import networkx as nx
+import pandas as pd
+from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb
+
+from graphrag.index.utils import load_graph
+
+from .typing import Communities
+
+log = logging.getLogger(__name__)
+
+
+@verb(name="cluster_graph")
+def cluster_graph(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any],
+    column: str,
+    to: str,
+    level_to: str | None = None,
+    **_kwargs,
+) -> TableContainer:
+    """
+    Apply a hierarchical clustering algorithm to a graph. The graph is expected to be in graphml format. The verb outputs a new column containing the clustered graph, and a new column containing the level of the graph.
+
+    ## Usage
+    ```yaml
+    verb: cluster_graph
+    args:
+        column: entity_graph # The name of the column containing the graph, should be a graphml graph
+        to: clustered_graph # The name of the column to output the clustered graph to
+        level_to: level # The name of the column to output the level to
+        strategy: <strategy config> # See strategies section below
+    ```
+
+    ## Strategies
+    The cluster graph verb uses a strategy to cluster the graph. The strategy is a json object which defines the strategy to use. The following strategies are available:
+
+    ### leiden
+    This strategy uses the leiden algorithm to cluster a graph. The strategy config is as follows:
+    ```yaml
+    strategy:
+        type: leiden
+        max_cluster_size: 10 # Optional, The max cluster size to use, default: 10
+        use_lcc: true # Optional, if the largest connected component should be used with the leiden algorithm, default: true
+        seed: 0xDEADBEEF # Optional, the seed to use for the leiden algorithm, default: 0xDEADBEEF
+        levels: [0, 1] # Optional, the levels to output, default: all the levels detected
+    ```
+    """
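+    # Outline: run the clustering strategy once per row, collect the set of
+    # levels it produced, render one clustered graphml string per level, and
+    # finally explode so each (level, graph) pair becomes its own output row.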
+    output_df = cast(pd.DataFrame, input.get_input())
+    results = output_df[column].apply(lambda graph: run_layout(strategy, graph))
+
+    community_map_to = "communities"
+    output_df[community_map_to] = results
+
+    level_to = level_to or f"{to}_level"
+    output_df[level_to] = output_df.apply(
+        lambda x: list({level for level, _, _ in x[community_map_to]}), axis=1
+    )
+    output_df[to] = [None] * len(output_df)
+
+    num_total = len(output_df)
+
+    # Go through each of the rows
+    graph_level_pairs_column: list[list[tuple[int, str]]] = []
+    for _, row in progress_iterable(
+        output_df.iterrows(), callbacks.progress, num_total
+    ):
+        levels = row[level_to]
+        graph_level_pairs: list[tuple[int, str]] = []
+
+        # For each of the levels, get the graph and add it to the list
+        for level in levels:
+            graph = "\n".join(
+                nx.generate_graphml(
+                    apply_clustering(
+                        cast(str, row[column]),
+                        cast(Communities, row[community_map_to]),
+                        level,
+                    )
+                )
+            )
+            graph_level_pairs.append((level, graph))
+        graph_level_pairs_column.append(graph_level_pairs)
+    output_df[to] = graph_level_pairs_column
+
+    # explode the list of (level, graph) pairs into separate rows
+    output_df = output_df.explode(to, ignore_index=True)
+
+    # split the (level, graph) pairs into separate columns
+    # TODO: There is probably a better way to do this
+    output_df[[level_to, to]] = pd.DataFrame(
+        output_df[to].tolist(), index=output_df.index
+    )
+
+    # clean up the community map
+    output_df.drop(columns=[community_map_to], inplace=True)
+
+    return TableContainer(table=output_df)
+
+
+def generate_entity_id(candidate: str) -> str:
+    """Generate a stable entity id by hashing the candidate string."""
+    h = sha256()
+    h.update(candidate.encode())
+    return h.hexdigest()
+
+
+# TODO: This should support str | nx.Graph as a graphml param
+def apply_clustering(
+    graphml: str, communities: Communities, level=0
+) -> nx.Graph:
+    """Apply clustering to a graphml string."""
+    graph = nx.parse_graphml(graphml)
+    for community_level, community_id, nodes in communities:
+        if level == community_level:
+            for node in nodes:
+                graph.nodes[node]["cluster"] = community_id
+                graph.nodes[node]["level"] = level
+
+    # add node degree
+    for node_degree in graph.degree:
+        graph.nodes[str(node_degree[0])]["degree"] = int(node_degree[1])
+
+    # Generate a unique ID for each entity and an incremental record id (a human readable id used as reference in the final report)
+    for index, node in enumerate(graph.nodes()):
+        graph.nodes[node]["human_readable_id"] = index
+        graph.nodes[node]["id"] = generate_entity_id(node)
+
+    # add ids to edges
+    for index, edge in enumerate(graph.edges()):
+        graph.edges[edge]["human_readable_id"] = index
+        graph.edges[edge]["level"] = level
+        graph.edges[edge]["id"] = generate_entity_id(f"{edge[0]}:{edge[1]}")
+
+    return graph
+
+
+class GraphCommunityStrategyType(str, Enum):
+    """GraphCommunityStrategyType class definition."""
+
+    leiden = "leiden"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+def run_layout(
+    strategy: dict[str, Any], graphml_or_graph: str | nx.Graph
+) -> Communities:
+    """Run layout method definition."""
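+    # Dispatch on strategy["type"] (only "leiden" is currently implemented) and
+    # flatten the per-level cluster maps into (level, cluster_id, nodes) tuples,
+    # i.e. the Communities shape consumed by cluster_graph above.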
+    graph = load_graph(graphml_or_graph)
+    if len(graph.nodes) == 0:
+        log.warning("Graph has no nodes")
+        return []
+
+    clusters: dict[int, dict[str, list[str]]] = {}
+    strategy_type = strategy.get("type", GraphCommunityStrategyType.leiden)
+    match strategy_type:
+        case GraphCommunityStrategyType.leiden:
+            from .strategies.leiden import run as run_leiden
+
+            clusters = run_leiden(graph, strategy)
+        case _:
+            msg = f"Unknown clustering strategy {strategy_type}"
+            raise ValueError(msg)
+
+    results: Communities = []
+    for level in clusters:
+        for cluster_id, nodes in clusters[level].items():
+            results.append((level, cluster_id, nodes))
+    return results
diff --git a/func-app/graphrag/index/verbs/graph/clustering/strategies/__init__.py b/func-app/graphrag/index/verbs/graph/clustering/strategies/__init__.py
new file mode 100644
index 0000000000..16a03f12d6
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/clustering/strategies/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Graph Clustering Strategies."""
diff --git a/func-app/graphrag/index/verbs/graph/clustering/strategies/leiden.py b/func-app/graphrag/index/verbs/graph/clustering/strategies/leiden.py
new file mode 100644
index 0000000000..ffc3688041
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/clustering/strategies/leiden.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run and _compute_leiden_communities methods definitions."""
+
+import logging
+from typing import Any
+
+import networkx as nx
+from graspologic.partition import hierarchical_leiden
+
+from graphrag.index.graph.utils import stable_largest_connected_component
+
+log = logging.getLogger(__name__)
+
+
+def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, list[str]]]:
+    """Run method definition."""
+    max_cluster_size = args.get("max_cluster_size", 10)
+    use_lcc = args.get("use_lcc", True)
+    if args.get("verbose", False):
+        log.info(
+            "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
+        )
+
+    node_id_to_community_map = _compute_leiden_communities(
+        graph=graph,
+        max_cluster_size=max_cluster_size,
+        use_lcc=use_lcc,
+        seed=args.get("seed", 0xDEADBEEF),
+    )
+    levels = args.get("levels")
+
+    # If they don't pass in levels, use them all
+    if levels is None:
+        levels = sorted(node_id_to_community_map.keys())
+
+    results_by_level: dict[int, dict[str, list[str]]] = {}
+    for level in levels:
+        result = {}
+        results_by_level[level] = result
+        for node_id, raw_community_id in node_id_to_community_map[level].items():
+            community_id = str(raw_community_id)
+            if community_id not in result:
+                result[community_id] = []
+            result[community_id].append(node_id)
+    return results_by_level
+
+
+# Taken from graph_intelligence & adapted
+def _compute_leiden_communities(
+    graph: nx.Graph | nx.DiGraph,
+    max_cluster_size: int,
+    use_lcc: bool,
+    seed=0xDEADBEEF,
+) -> dict[int, dict[str, int]]:
+    """Return Leiden root communities."""
+    if use_lcc:
+        graph = stable_largest_connected_component(graph)
+
+    community_mapping = hierarchical_leiden(
+        graph, max_cluster_size=max_cluster_size, random_seed=seed
+    )
+    results: dict[int, dict[str, int]] = {}
+    for partition in community_mapping:
+        results[partition.level] = results.get(partition.level, {})
+        results[partition.level][partition.node] = partition.cluster
+
+    return results
diff --git a/func-app/graphrag/index/verbs/graph/clustering/typing.py b/func-app/graphrag/index/verbs/graph/clustering/typing.py
new file mode 100644
index 0000000000..4d6fc7e601
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/clustering/typing.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing Communities list definition."""
+
+Communities = list[tuple[int, str, list[str]]]
diff --git a/func-app/graphrag/index/verbs/graph/compute_edge_combined_degree.py b/func-app/graphrag/index/verbs/graph/compute_edge_combined_degree.py
new file mode 100644
index 0000000000..1f2dd71972
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/compute_edge_combined_degree.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the compute_edge_combined_degree method definition."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+from graphrag.index.utils.ds_util import get_required_input_table
+
+
+@verb(name="compute_edge_combined_degree")
+def compute_edge_combined_degree(
+    input: VerbInput,
+    to: str = "rank",
+    node_name_column: str = "title",
+    node_degree_column: str = "degree",
+    edge_source_column: str = "source",
+    edge_target_column: str = "target",
+    **_kwargs,
+) -> TableContainer:
+    """
+    Compute the combined degree for each edge in a graph, i.e. the sum of the source node degree and the target node degree.
+
+    Inputs Tables:
+    - input: The edge table
+    - nodes: The nodes table.
+
+    Args:
+    - to: The name of the column to output the combined degree to. Default="rank"
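+
+    ## Usage
+    A minimal, illustrative configuration (all arguments are optional and
+    default to the values shown in the signature above):
+    ```yaml
+    verb: compute_edge_combined_degree
+    args:
+        to: rank # The name of the column to output the combined degree to
+    ```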
+    """
+    edge_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
+    if to in edge_df.columns:
+        return TableContainer(table=edge_df)
+    node_degree_df = _get_node_degree_table(input, node_name_column, node_degree_column)
+
+    def join_to_degree(df: pd.DataFrame, column: str) -> pd.DataFrame:
+        degree_column = _degree_colname(column)
+        result = df.merge(
+            node_degree_df.rename(
+                columns={node_name_column: column, node_degree_column: degree_column}
+            ),
+            on=column,
+            how="left",
+        )
+        result[degree_column] = result[degree_column].fillna(0)
+        return result
+
+    edge_df = join_to_degree(edge_df, edge_source_column)
+    edge_df = join_to_degree(edge_df, edge_target_column)
+    edge_df[to] = (
+        edge_df[_degree_colname(edge_source_column)]
+        + edge_df[_degree_colname(edge_target_column)]
+    )
+
+    return TableContainer(table=edge_df)
+
+
+def _degree_colname(column: str) -> str:
+    return f"{column}_degree"
+
+
+def _get_node_degree_table(
+    input: VerbInput, node_name_column: str, node_degree_column: str
+) -> pd.DataFrame:
+    nodes_container = get_required_input_table(input, "nodes")
+    nodes = cast(pd.DataFrame, nodes_container.table)
+    return cast(pd.DataFrame, nodes[[node_name_column, node_degree_column]])
diff --git a/func-app/graphrag/index/verbs/graph/create.py b/func-app/graphrag/index/verbs/graph/create.py
new file mode 100644
index 0000000000..eaf06284ef
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/create.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing create_graph, _get_node_attributes, _get_edge_attributes and _get_attribute_column_mapping methods definition."""
+
+from typing import Any
+
+import networkx as nx
+import pandas as pd
+from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb
+
+from graphrag.index.utils import clean_str
+
+DEFAULT_NODE_ATTRIBUTES = ["label", "type", "id", "name", "description", "community"]
+DEFAULT_EDGE_ATTRIBUTES = ["label", "type", "name", "source", "target"]
+
+
+@verb(name="create_graph")
+def create_graph(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    to: str,
+    type: str,  # noqa A002
+    graph_type: str = "undirected",
+    **kwargs,
+) -> TableContainer:
+    """
+    Create a graph from a dataframe. The verb outputs a new column containing the graph.
+
+    > Note: This will roll up all rows into a single graph.
+
+    ## Usage
+    ```yaml
+    verb: create_graph
+    args:
+        type: node # The type of graph to create, one of: node, edge
+        to: <column_name> # The name of the column to output the graph to, this will be a graphml graph
+        attributes: # The attributes for the nodes / edges
+            # If using the node type, the following attributes are required:
+            id: <id_column_name>
+
+            # If using the edge type, the following attributes are required:
+            source: <source_column_name>
+            target: <target_column_name>
+
+            # Other attributes can be added as follows:
+            <attribute_name>: <attribute_column_name>
+            ... for each attribute
+    ```
+    """
+    if type != "node" and type != "edge":
+        msg = f"Unknown type {type}"
+        raise ValueError(msg)
+
+    input_df = input.get_input()
+    num_total = len(input_df)
+    out_graph: nx.Graph = _create_nx_graph(graph_type)
+
+    in_attributes = (
+        _get_node_attributes(kwargs) if type == "node" else _get_edge_attributes(kwargs)
+    )
+
+    # At this point, _get_node_attributes and _get_edge_attributes have already validated
+    id_col = in_attributes.get(
+        "id", in_attributes.get("label", in_attributes.get("name", None))
+    )
+    source_col = in_attributes.get("source", None)
+    target_col = in_attributes.get("target", None)
+
+    for _, row in progress_iterable(input_df.iterrows(), callbacks.progress, num_total):
+        item_attributes = {
+            clean_str(key): _clean_value(row[value])
+            for key, value in in_attributes.items()
+            if value in row
+        }
+        if type == "node":
+            id = clean_str(row[id_col])
+            out_graph.add_node(id, **item_attributes)
+        elif type == "edge":
+            source = clean_str(row[source_col])
+            target = clean_str(row[target_col])
+            out_graph.add_edge(source, target, **item_attributes)
+
+    graphml_string = "".join(nx.generate_graphml(out_graph))
+    output_df = pd.DataFrame([{to: graphml_string}])
+    return TableContainer(table=output_df)
+
+
+def _clean_value(value: Any) -> str:
+    if value is None:
+        return ""
+    if isinstance(value, str):
+        return clean_str(value)
+
+    msg = f"Value must be a string or None, got {type(value)}"
+    raise TypeError(msg)
+
+
+def _get_node_attributes(args: dict[str, Any]) -> dict[str, Any]:
+    mapping = _get_attribute_column_mapping(
+        args.get("attributes", DEFAULT_NODE_ATTRIBUTES)
+    )
+    if "id" not in mapping and "label" not in mapping and "name" not in mapping:
+        msg = "You must specify an id, label, or name column in the node attributes"
+        raise ValueError(msg)
+    return mapping
+
+
+def _get_edge_attributes(args: dict[str, Any]) -> dict[str, Any]:
+    mapping = _get_attribute_column_mapping(
+        args.get("attributes", DEFAULT_EDGE_ATTRIBUTES)
+    )
+    if "source" not in mapping or "target" not in mapping:
+        msg = "You must specify a source and target column in the edge attributes"
+        raise ValueError(msg)
+    return mapping
+
+
+def _get_attribute_column_mapping(
+    in_attributes: dict[str, Any] | list[str],
+) -> dict[str, str]:
+    # It's already an attribute: column dict
+    if isinstance(in_attributes, dict):
+        return {
+            **in_attributes,
+        }
+
+    return {attrib: attrib for attrib in in_attributes}
+
+
+def _create_nx_graph(graph_type: str) -> nx.Graph:
+    if graph_type == "directed":
+        return nx.DiGraph()
+
+    return nx.Graph()
diff --git a/func-app/graphrag/index/verbs/graph/embed/__init__.py b/func-app/graphrag/index/verbs/graph/embed/__init__.py
new file mode 100644
index 0000000000..4ca8168c3c
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/embed/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph embed package root."""
+
+from .embed_graph import EmbedGraphStrategyType, embed_graph
+
+__all__ = ["EmbedGraphStrategyType", "embed_graph"]
diff --git a/func-app/graphrag/index/verbs/graph/embed/embed_graph.py b/func-app/graphrag/index/verbs/graph/embed/embed_graph.py
new file mode 100644
index 0000000000..8691d343f0
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/embed/embed_graph.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing embed_graph and run_embeddings methods definition."""
+
+from enum import Enum
+from typing import Any, cast
+
+import networkx as nx
+import pandas as pd
+from datashaper import TableContainer, VerbCallbacks, VerbInput, derive_from_rows, verb
+
+from graphrag.index.utils import load_graph
+
+from .typing import NodeEmbeddings
+
+
+class EmbedGraphStrategyType(str, Enum):
+    """EmbedGraphStrategyType class definition."""
+
+    node2vec = "node2vec"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+@verb(name="embed_graph")
+async def embed_graph(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any],
+    column: str,
+    to: str,
+    **kwargs,
+) -> TableContainer:
+    """
+    Embed a graph into a vector space. The graph is expected to be in graphml format. The verb outputs a new column containing a mapping between node_id and vector.
+
+    ## Usage
+    ```yaml
+    verb: embed_graph
+    args:
+        column: clustered_graph # The name of the column containing the graph, should be a graphml graph
+        to: embeddings # The name of the column to output the embeddings to
+        strategy: <strategy config> # See strategies section below
+    ```
+
+    ## Strategies
+    The embed_graph verb uses a strategy to embed the graph. The strategy is an object which defines the strategy to use. The following strategies are available:
+
+    ### node2vec
+    This strategy uses the node2vec algorithm to embed a graph. The strategy config is as follows:
+
+    ```yaml
+    strategy:
+        type: node2vec
+        dimensions: 1536 # Optional, The number of dimensions to use for the embedding, default: 1536
+        num_walks: 10 # Optional, The number of walks to use for the embedding, default: 10
+        walk_length: 40 # Optional, The walk length to use for the embedding, default: 40
+        window_size: 2 # Optional, The window size to use for the embedding, default: 2
+        iterations: 3 # Optional, The number of iterations to use for the embedding, default: 3
+        random_seed: 86 # Optional, The random seed to use for the embedding, default: 86
+    ```
+    """
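+    # Each row's graph is embedded independently (and concurrently, via
+    # derive_from_rows); the value stored in `to` is a node-label -> vector
+    # mapping, as described in the docstring above.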
+    output_df = cast(pd.DataFrame, input.get_input())
+
+    strategy_type = strategy.get("type", EmbedGraphStrategyType.node2vec)
+    strategy_args = {**strategy}
+
+    async def run_strategy(row):  # noqa RUF029 async is required for interface
+        return run_embeddings(strategy_type, cast(Any, row[column]), strategy_args)
+
+    results = await derive_from_rows(
+        output_df,
+        run_strategy,
+        callbacks=callbacks,
+        num_threads=kwargs.get("num_threads", None),
+    )
+    output_df[to] = list(results)
+    return TableContainer(table=output_df)
+
+
+def run_embeddings(
+    strategy: EmbedGraphStrategyType,
+    graphml_or_graph: str | nx.Graph,
+    args: dict[str, Any],
+) -> NodeEmbeddings:
+    """Run embeddings method definition."""
+    graph = load_graph(graphml_or_graph)
+    match strategy:
+        case EmbedGraphStrategyType.node2vec:
+            from .strategies.node_2_vec import run as run_node_2_vec
+
+            return run_node_2_vec(graph, args)
+        case _:
+            msg = f"Unknown strategy {strategy}"
+            raise ValueError(msg)
diff --git a/func-app/graphrag/index/verbs/graph/embed/strategies/__init__.py b/func-app/graphrag/index/verbs/graph/embed/strategies/__init__.py
new file mode 100644
index 0000000000..ef85198eb7
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/embed/strategies/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Graph embedding strategies."""
diff --git a/func-app/graphrag/index/verbs/graph/embed/strategies/node_2_vec.py b/func-app/graphrag/index/verbs/graph/embed/strategies/node_2_vec.py
new file mode 100644
index 0000000000..eb329519ed
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/embed/strategies/node_2_vec.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run method definition."""
+
+from typing import Any
+
+import networkx as nx
+
+from graphrag.index.graph.embedding import embed_nod2vec
+from graphrag.index.graph.utils import stable_largest_connected_component
+from graphrag.index.verbs.graph.embed.typing import NodeEmbeddings
+
+
+def run(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings:
+    """Run method definition."""
+    if args.get("use_lcc", True):
+        graph = stable_largest_connected_component(graph)
+
+    # create graph embedding using node2vec
+    embeddings = embed_nod2vec(
+        graph=graph,
+        dimensions=args.get("dimensions", 1536),
+        num_walks=args.get("num_walks", 10),
+        walk_length=args.get("walk_length", 40),
+        window_size=args.get("window_size", 2),
+        iterations=args.get("iterations", 3),
+        random_seed=args.get("random_seed", 86),
+    )
+
+    pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True)
+    sorted_pairs = sorted(pairs, key=lambda x: x[0])
+
+    return dict(sorted_pairs)
diff --git a/func-app/graphrag/index/verbs/graph/embed/typing.py b/func-app/graphrag/index/verbs/graph/embed/typing.py
new file mode 100644
index 0000000000..fea792c9b1
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/embed/typing.py
@@ -0,0 +1,12 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing different lists and dictionaries."""
+
+# Use this for now instead of a wrapper
+from typing import Any
+
+NodeList = list[str]
+EmbeddingList = list[Any]
+NodeEmbeddings = dict[str, list[float]]
+"""Label -> Embedding"""
diff --git a/func-app/graphrag/index/verbs/graph/layout/__init__.py b/func-app/graphrag/index/verbs/graph/layout/__init__.py
new file mode 100644
index 0000000000..74584f83ed
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/layout/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph layout package root."""
+
+from .layout_graph import layout_graph
+
+__all__ = ["layout_graph"]
diff --git a/func-app/graphrag/index/verbs/graph/layout/layout_graph.py b/func-app/graphrag/index/verbs/graph/layout/layout_graph.py
new file mode 100644
index 0000000000..e1b55b1183
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/layout/layout_graph.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing layout_graph, _run_layout and _apply_layout_to_graph methods definition."""
+
+from enum import Enum
+from typing import Any, cast
+
+import networkx as nx
+import pandas as pd
+from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_callback, verb
+
+from graphrag.index.graph.visualization import GraphLayout
+from graphrag.index.utils import load_graph
+from graphrag.index.verbs.graph.embed.typing import NodeEmbeddings
+
+
+class LayoutGraphStrategyType(str, Enum):
+    """LayoutGraphStrategyType class definition."""
+
+    umap = "umap"
+    zero = "zero"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+@verb(name="layout_graph")
+def layout_graph(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any],
+    embeddings_column: str,
+    graph_column: str,
+    to: str,
+    graph_to: str | None = None,
+    **_kwargs: dict,
+) -> TableContainer:
+    """
+    Apply a layout algorithm to a graph. The graph is expected to be in graphml format. The verb outputs a new column containing the laid out graph.
+
+    ## Usage
+    ```yaml
+    verb: layout_graph
+    args:
+        graph_column: clustered_graph # The name of the column containing the graph, should be a graphml graph
+        embeddings_column: embeddings # The name of the column containing the embeddings
+        to: node_positions # The name of the column to output the node positions to
+        graph_to: positioned_graph # The name of the column to output the positioned graph to
+        strategy: <strategy config> # See strategies section below
+    ```
+
+    ## Strategies
+    The layout graph verb uses a strategy to lay out the graph. The strategy is a json object which defines the strategy to use. The following strategies are available:
+
+    ### umap
+    This strategy uses the umap algorithm to lay out a graph. The strategy config is as follows:
+    ```yaml
+    strategy:
+        type: umap
+        n_neighbors: 5 # Optional, The number of neighbors to use for the umap algorithm, default: 5
+        min_dist: 0.75 # Optional, The min distance to use for the umap algorithm, default: 0.75
+    ```
+    """
+    output_df = cast(pd.DataFrame, input.get_input())
+
+    num_items = len(output_df)
+    strategy_type = strategy.get("type", LayoutGraphStrategyType.umap)
+    strategy_args = {**strategy}
+
+    has_embeddings = embeddings_column in output_df.columns
+
+    layouts = output_df.apply(
+        progress_callback(
+            lambda row: _run_layout(
+                strategy_type,
+                row[graph_column],
+                row[embeddings_column] if has_embeddings else {},
+                strategy_args,
+                callbacks,
+            ),
+            callbacks.progress,
+            num_items,
+        ),
+        axis=1,
+    )
+    output_df[to] = layouts.apply(lambda layout: [pos.to_pandas() for pos in layout])
+    if graph_to is not None:
+        output_df[graph_to] = output_df.apply(
+            lambda row: _apply_layout_to_graph(
+                row[graph_column], cast(GraphLayout, layouts[row.name])
+            ),
+            axis=1,
+        )
+    return TableContainer(table=output_df)
+
+
+def _run_layout(
+    strategy: LayoutGraphStrategyType,
+    graphml_or_graph: str | nx.Graph,
+    embeddings: NodeEmbeddings,
+    args: dict[str, Any],
+    reporter: VerbCallbacks,
+) -> GraphLayout:
+    graph = load_graph(graphml_or_graph)
+    match strategy:
+        case LayoutGraphStrategyType.umap:
+            from .methods.umap import run as run_umap
+
+            return run_umap(
+                graph,
+                embeddings,
+                args,
+                lambda e, stack, d: reporter.error("Error in Umap", e, stack, d),
+            )
+        case LayoutGraphStrategyType.zero:
+            from .methods.zero import run as run_zero
+
+            return run_zero(
+                graph,
+                args,
+                lambda e, stack, d: reporter.error("Error in Zero", e, stack, d),
+            )
+        case _:
+            msg = f"Unknown strategy {strategy}"
+            raise ValueError(msg)
+
+
+def _apply_layout_to_graph(
+    graphml_or_graph: str | nx.Graph, layout: GraphLayout
+) -> str:
+    graph = load_graph(graphml_or_graph)
+    for node_position in layout:
+        if node_position.label in graph.nodes:
+            graph.nodes[node_position.label]["x"] = node_position.x
+            graph.nodes[node_position.label]["y"] = node_position.y
+            graph.nodes[node_position.label]["size"] = node_position.size
+    return "\n".join(nx.generate_graphml(graph))
diff --git a/func-app/graphrag/index/verbs/graph/layout/methods/__init__.py b/func-app/graphrag/index/verbs/graph/layout/methods/__init__.py
new file mode 100644
index 0000000000..5d5054122b
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/layout/methods/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Graph Layout Methods."""
diff --git a/func-app/graphrag/index/verbs/graph/layout/methods/umap.py b/func-app/graphrag/index/verbs/graph/layout/methods/umap.py
new file mode 100644
index 0000000000..a4bc7c2818
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/layout/methods/umap.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run and _create_node_position methods definitions."""
+
+import logging
+import traceback
+from typing import Any
+
+import networkx as nx
+import numpy as np
+
+from graphrag.index.graph.visualization import (
+    GraphLayout,
+    NodePosition,
+    compute_umap_positions,
+)
+from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.verbs.graph.embed.typing import NodeEmbeddings
+
+# TODO: This could be handled more elegantly, like what columns to use
+# for "size" or "cluster"
+# We could also have a boolean to indicate to use node sizes or clusters
+
+log = logging.getLogger(__name__)
+
+
+def run(
+    graph: nx.Graph,
+    embeddings: NodeEmbeddings,
+    args: dict[str, Any],
+    on_error: ErrorHandlerFn,
+) -> GraphLayout:
+    """Run method definition."""
+    node_clusters = []
+    node_sizes = []
+
+    embeddings = _filter_raw_embeddings(embeddings)
+    nodes = list(embeddings.keys())
+    embedding_vectors = [embeddings[node_id] for node_id in nodes]
+
+    for node_id in nodes:
+        node = graph.nodes[node_id]
+        cluster = node.get("cluster", node.get("community", -1))
+        node_clusters.append(cluster)
+        size = node.get("degree", node.get("size", 0))
+        node_sizes.append(size)
+
+    additional_args = {}
+    if len(node_clusters) > 0:
+        additional_args["node_categories"] = node_clusters
+    if len(node_sizes) > 0:
+        additional_args["node_sizes"] = node_sizes
+
+    try:
+        return compute_umap_positions(
+            embedding_vectors=np.array(embedding_vectors),
+            node_labels=nodes,
+            **additional_args,
+            min_dist=args.get("min_dist", 0.75),
+            n_neighbors=args.get("n_neighbors", 5),
+        )
+    except Exception as e:
+        log.exception("Error running UMAP")
+        on_error(e, traceback.format_exc(), None)
+        # Umap may fail due to input sparseness or memory pressure.
+        # For now, in these cases, we'll just return a layout with all nodes at (0, 0)
+        result = []
+        for i in range(len(nodes)):
+            cluster = node_clusters[i] if len(node_clusters) > 0 else 1
+            result.append(
+                NodePosition(x=0, y=0, label=nodes[i], size=0, cluster=str(cluster))
+            )
+        return result
+
+
+def _filter_raw_embeddings(embeddings: NodeEmbeddings) -> NodeEmbeddings:
+    return {
+        node_id: embedding
+        for node_id, embedding in embeddings.items()
+        if embedding is not None
+    }
diff --git a/func-app/graphrag/index/verbs/graph/layout/methods/zero.py b/func-app/graphrag/index/verbs/graph/layout/methods/zero.py
new file mode 100644
index 0000000000..f41d2d4ca4
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/layout/methods/zero.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run and _create_node_position methods definitions."""
+
+import logging
+import traceback
+from typing import Any
+
+import networkx as nx
+
+from graphrag.index.graph.visualization import (
+    GraphLayout,
+    NodePosition,
+    get_zero_positions,
+)
+from graphrag.index.typing import ErrorHandlerFn
+
+# TODO: This could be handled more elegantly, like what columns to use
+# for "size" or "cluster"
+# We could also have a boolean to indicate to use node sizes or clusters
+
+log = logging.getLogger(__name__)
+
+
+def run(
+    graph: nx.Graph,
+    _args: dict[str, Any],
+    on_error: ErrorHandlerFn,
+) -> GraphLayout:
+    """Run method definition."""
+    node_clusters = []
+    node_sizes = []
+
+    nodes = list(graph.nodes)
+
+    for node_id in nodes:
+        node = graph.nodes[node_id]
+        cluster = node.get("cluster", node.get("community", -1))
+        node_clusters.append(cluster)
+        size = node.get("degree", node.get("size", 0))
+        node_sizes.append(size)
+
+    additional_args = {}
+    if len(node_clusters) > 0:
+        additional_args["node_categories"] = node_clusters
+    if len(node_sizes) > 0:
+        additional_args["node_sizes"] = node_sizes
+
+    try:
+        return get_zero_positions(node_labels=nodes, **additional_args)
+    except Exception as e:
+        log.exception("Error running zero-position")
+        on_error(e, traceback.format_exc(), None)
+        # The zero-position computation may fail.
+        # For now, in these cases, we'll just return a layout with all nodes at (0, 0)
+        result = []
+        for i in range(len(nodes)):
+            cluster = node_clusters[i] if len(node_clusters) > 0 else 1
+            result.append(
+                NodePosition(x=0, y=0, label=nodes[i], size=0, cluster=str(cluster))
+            )
+        return result
diff --git a/func-app/graphrag/index/verbs/graph/merge/__init__.py b/func-app/graphrag/index/verbs/graph/merge/__init__.py
new file mode 100644
index 0000000000..f718827942
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/merge/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph merge package root."""
+
+from .merge_graphs import merge_graphs
+
+__all__ = ["merge_graphs"]
diff --git a/func-app/graphrag/index/verbs/graph/merge/defaults.py b/func-app/graphrag/index/verbs/graph/merge/defaults.py
new file mode 100644
index 0000000000..80c60331c6
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/merge/defaults.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A file containing DEFAULT_NODE_OPERATIONS, DEFAULT_EDGE_OPERATIONS and DEFAULT_CONCAT_SEPARATOR values definition."""
+
+from .typing import BasicMergeOperation
+
+DEFAULT_NODE_OPERATIONS = {
+    "*": {
+        "operation": BasicMergeOperation.Replace,
+    }
+}
+
+DEFAULT_EDGE_OPERATIONS = {
+    "*": {
+        "operation": BasicMergeOperation.Replace,
+    },
+    "weight": "sum",
+}
+
+DEFAULT_CONCAT_SEPARATOR = ","
diff --git a/func-app/graphrag/index/verbs/graph/merge/merge_graphs.py b/func-app/graphrag/index/verbs/graph/merge/merge_graphs.py
new file mode 100644
index 0000000000..8ab3fa47f7
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/merge/merge_graphs.py
@@ -0,0 +1,217 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing merge_graphs, merge_nodes, merge_edges, merge_attributes, apply_merge_operation and _get_detailed_attribute_merge_operation methods definitions."""
+
+from typing import Any, cast
+
+import networkx as nx
+import pandas as pd
+from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb
+
+from graphrag.index.utils import load_graph
+
+from .defaults import (
+    DEFAULT_CONCAT_SEPARATOR,
+    DEFAULT_EDGE_OPERATIONS,
+    DEFAULT_NODE_OPERATIONS,
+)
+from .typing import (
+    BasicMergeOperation,
+    DetailedAttributeMergeOperation,
+    NumericOperation,
+    StringOperation,
+)
+
+
+@verb(name="merge_graphs")
+def merge_graphs(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    column: str,
+    to: str,
+    nodes: dict[str, Any] = DEFAULT_NODE_OPERATIONS,
+    edges: dict[str, Any] = DEFAULT_EDGE_OPERATIONS,
+    **_kwargs,
+) -> TableContainer:
+    """
+    Merge multiple graphs together. The graphs are expected to be in graphml format. The verb outputs a new column containing the merged graph.
+
+    > Note: This will merge all rows into a single graph.
+
+    ## Usage
+    ```yaml
+    verb: merge_graph
+    args:
+        column: clustered_graph # The name of the column containing the graph, should be a graphml graph
+        to: merged_graph # The name of the column to output the merged graph to
+        nodes: <node operations> # See node operations section below
+        edges: <edge operations> # See edge operations section below
+    ```
+
+    ## Node Operations
+    The merge graph verb can perform operations on the nodes of the graph.
+
+    ### Usage
+    ```yaml
+    nodes:
+        <attribute name>: <operation>
+        ... for each attribute or use the special value "*" for all attributes
+    ```
+
+    ## Edge Operations
+    The merge graph verb can perform operations on the edges of the graph.
+
+    ### Usage
+    ```yaml
+    edges:
+        <attribute name>: <operation>
+        ... for each attribute or use the special value "*" for all attributes
+    ```
+
+    ## Operations
+    The merge graph verb can perform operations on the nodes and edges of the graph. The following operations are available:
+
+    - __replace__: This operation replaces the attribute with the last value seen.
+    - __skip__: This operation skips the attribute, and just uses the first value seen.
+    - __concat__: This operation concatenates the attribute with the last value seen.
+    - __sum__: This operation sums the attribute with the last value seen.
+    - __max__: This operation takes the max of the attribute with the last value seen.
+    - __min__: This operation takes the min of the attribute with the last value seen.
+    - __average__: This operation takes the mean of the attribute with the last value seen.
+    - __multiply__: This operation multiplies the attribute with the last value seen.
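+
+    ### Example
+    An illustrative configuration (the attribute names here are hypothetical;
+    a plain string value such as `weight: sum` is shorthand for
+    `{ operation: sum }`):
+    ```yaml
+    nodes:
+        "*": replace
+        description:
+            operation: concat
+            separator: ","
+            distinct: true
+    edges:
+        weight: sum
+    ```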
+ """ + input_df = input.get_input() + output = pd.DataFrame() + + node_ops = { + attrib: _get_detailed_attribute_merge_operation(value) + for attrib, value in nodes.items() + } + edge_ops = { + attrib: _get_detailed_attribute_merge_operation(value) + for attrib, value in edges.items() + } + + mega_graph = nx.Graph() + num_total = len(input_df) + for graphml in progress_iterable(input_df[column], callbacks.progress, num_total): + graph = load_graph(cast(str | nx.Graph, graphml)) + merge_nodes(mega_graph, graph, node_ops) + merge_edges(mega_graph, graph, edge_ops) + + output[to] = ["\n".join(nx.generate_graphml(mega_graph))] + + return TableContainer(table=output) + + +def merge_nodes( + target: nx.Graph, + subgraph: nx.Graph, + node_ops: dict[str, DetailedAttributeMergeOperation], +): + """Merge nodes from subgraph into target using the operations defined in node_ops.""" + for node in subgraph.nodes: + if node not in target.nodes: + target.add_node(node, **(subgraph.nodes[node] or {})) + else: + merge_attributes(target.nodes[node], subgraph.nodes[node], node_ops) + + +def merge_edges( + target_graph: nx.Graph, + subgraph: nx.Graph, + edge_ops: dict[str, DetailedAttributeMergeOperation], +): + """Merge edges from subgraph into target using the operations defined in edge_ops.""" + for source, target, edge_data in subgraph.edges(data=True): # type: ignore + if not target_graph.has_edge(source, target): + target_graph.add_edge(source, target, **(edge_data or {})) + else: + merge_attributes(target_graph.edges[(source, target)], edge_data, edge_ops) + + +def merge_attributes( + target_item: dict[str, Any] | None, + source_item: dict[str, Any] | None, + ops: dict[str, DetailedAttributeMergeOperation], +): + """Merge attributes from source_item into target_item using the operations defined in ops.""" + source_item = source_item or {} + target_item = target_item or {} + for op_attrib, op in ops.items(): + if op_attrib == "*": + for attrib in source_item: + # If there is a specific handler for this attribute, use it + # i.e. 
+                if attrib not in ops:
+                    apply_merge_operation(target_item, source_item, attrib, op)
+        else:
+            if op_attrib in source_item or op_attrib in target_item:
+                apply_merge_operation(target_item, source_item, op_attrib, op)
+
+
+def apply_merge_operation(
+    target_item: dict[str, Any] | None,
+    source_item: dict[str, Any] | None,
+    attrib: str,
+    op: DetailedAttributeMergeOperation,
+):
+    """Apply the merge operation to the attribute."""
+    source_item = source_item or {}
+    target_item = target_item or {}
+
+    if (
+        op.operation == BasicMergeOperation.Replace
+        or op.operation == StringOperation.Replace
+    ):
+        target_item[attrib] = source_item.get(attrib, None) or ""
+    elif (
+        op.operation == BasicMergeOperation.Skip or op.operation == StringOperation.Skip
+    ):
+        target_item[attrib] = target_item.get(attrib, None) or ""
+    elif op.operation == StringOperation.Concat:
+        separator = op.separator or DEFAULT_CONCAT_SEPARATOR
+        target_attrib = target_item.get(attrib, "") or ""
+        source_attrib = source_item.get(attrib, "") or ""
+        target_item[attrib] = f"{target_attrib}{separator}{source_attrib}"
+        if op.distinct:
+            # TODO: Slow
+            target_item[attrib] = separator.join(
+                sorted(set(target_item[attrib].split(separator)))
+            )
+
+    # We're assuming that the attribute is numeric
+    elif op.operation == NumericOperation.Sum:
+        target_item[attrib] = (target_item.get(attrib, 0) or 0) + (
+            source_item.get(attrib, 0) or 0
+        )
+    elif op.operation == NumericOperation.Average:
+        target_item[attrib] = (
+            (target_item.get(attrib, 0) or 0) + (source_item.get(attrib, 0) or 0)
+        ) / 2
+    elif op.operation == NumericOperation.Max:
+        target_item[attrib] = max(
+            (target_item.get(attrib, 0) or 0), (source_item.get(attrib, 0) or 0)
+        )
+    elif op.operation == NumericOperation.Min:
+        target_item[attrib] = min(
+            (target_item.get(attrib, 0) or 0), (source_item.get(attrib, 0) or 0)
+        )
+    elif op.operation == NumericOperation.Multiply:
+        target_item[attrib] = (target_item.get(attrib, 1) or 1) * (
+            source_item.get(attrib, 1) or 1
+        )
+    else:
+        msg = f"Invalid operation {op.operation}"
+        raise ValueError(msg)
+
+
+def _get_detailed_attribute_merge_operation(
+    value: str | dict[str, Any],
+) -> DetailedAttributeMergeOperation:
+    """Normalize the AttributeMergeOperation into a DetailedAttributeMergeOperation."""
+    if isinstance(value, str):
+        return DetailedAttributeMergeOperation(operation=value)
+    return DetailedAttributeMergeOperation(**value)
diff --git a/func-app/graphrag/index/verbs/graph/merge/typing.py b/func-app/graphrag/index/verbs/graph/merge/typing.py
new file mode 100644
index 0000000000..0e534f516c
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/merge/typing.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing 'BasicMergeOperation', 'StringOperation', 'NumericOperation' and 'DetailedAttributeMergeOperation' models."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+
+class BasicMergeOperation(str, Enum):
+    """Basic Merge Operation class definition."""
+
+    Replace = "replace"
+    Skip = "skip"
+
+
+class StringOperation(str, Enum):
+    """String Operation class definition."""
+
+    Concat = "concat"
+    Replace = "replace"
+    Skip = "skip"
+
+
+class NumericOperation(str, Enum):
+    """Numeric Operation class definition."""
+
+    Sum = "sum"
+    Average = "average"
+    Max = "max"
+    Min = "min"
+    Multiply = "multiply"
+    Replace = "replace"
+    Skip = "skip"
+
+
+@dataclass
+class DetailedAttributeMergeOperation:
+    """Detailed attribute merge operation class definition."""
+
+    operation: str  # StringOperation | NumericOperation
+
+    # concat
+    separator: str | None = None
+    delimiter: str | None = None
+    distinct: bool = False
+
+
+AttributeMergeOperation = str | DetailedAttributeMergeOperation
diff --git a/func-app/graphrag/index/verbs/graph/report/__init__.py b/func-app/graphrag/index/verbs/graph/report/__init__.py
new file mode 100644
index 0000000000..e47d9ccef5
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Indexing Engine graph report package root."""
+
+from .create_community_reports import (
+    CreateCommunityReportsStrategyType,
+    create_community_reports,
+)
+from .prepare_community_reports import prepare_community_reports
+from .prepare_community_reports_claims import prepare_community_reports_claims
+from .prepare_community_reports_edges import prepare_community_reports_edges
+from .prepare_community_reports_nodes import prepare_community_reports_nodes
+from .restore_community_hierarchy import restore_community_hierarchy
+
+__all__ = [
+    "CreateCommunityReportsStrategyType",
+    "create_community_reports",
+    "prepare_community_reports",
+    "prepare_community_reports_claims",
+    "prepare_community_reports_edges",
+    "prepare_community_reports_nodes",
+    "restore_community_hierarchy",
+]
diff --git a/func-app/graphrag/index/verbs/graph/report/create_community_reports.py b/func-app/graphrag/index/verbs/graph/report/create_community_reports.py
new file mode 100644
index 0000000000..c67d5107e8
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/create_community_reports.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing create_community_reports and load_strategy methods definition."""
+
+import logging
+from enum import Enum
+from typing import cast
+
+import pandas as pd
+from datashaper import (
+    AsyncType,
+    NoopVerbCallbacks,
+    TableContainer,
+    VerbCallbacks,
+    VerbInput,
+    derive_from_rows,
+    progress_ticker,
+    verb,
+)
+
+import graphrag.config.defaults as defaults
+import graphrag.index.graph.extractors.community_reports.schemas as schemas
+from graphrag.index.cache import PipelineCache
+from graphrag.index.graph.extractors.community_reports import (
+    get_levels,
+    prep_community_report_context,
+)
+from graphrag.index.utils.ds_util import get_required_input_table
+
+from .strategies.typing import CommunityReport, CommunityReportsStrategy
+
+log = logging.getLogger(__name__)
+
+
+class CreateCommunityReportsStrategyType(str, Enum):
+    """CreateCommunityReportsStrategyType class definition."""
+
+    graph_intelligence = "graph_intelligence"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+@verb(name="create_community_reports")
+async def create_community_reports(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    strategy: dict,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    num_threads: int = 4,
+    **_kwargs,
+) -> TableContainer:
+    """Generate a community report for each community in the graph."""
+    log.debug("create_community_reports strategy=%s", strategy)
+    local_contexts = cast(pd.DataFrame, input.get_input())
+    nodes_ctr = get_required_input_table(input, "nodes")
+    nodes = cast(pd.DataFrame, nodes_ctr.table)
+    community_hierarchy_ctr = get_required_input_table(input, "community_hierarchy")
+    community_hierarchy = cast(pd.DataFrame, community_hierarchy_ctr.table)
+
+    levels = get_levels(nodes)
+    reports: list[CommunityReport | None] = []
+    tick = progress_ticker(callbacks.progress, len(local_contexts))
+    runner = load_strategy(strategy["type"])
+
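+    # Reports are generated one level at a time; the reports accumulated so far
+    # are fed back into prep_community_report_context so that each level's
+    # context can build on the reports already written for other levels.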
+    for level in levels:
+        level_contexts = prep_community_report_context(
+            pd.DataFrame(reports),
+            local_context_df=local_contexts,
+            community_hierarchy_df=community_hierarchy,
+            level=level,
+            max_tokens=strategy.get(
+                "max_input_tokens", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
+            ),
+        )
+
+        async def run_generate(record):
+            result = await _generate_report(
+                runner,
+                community_id=record[schemas.NODE_COMMUNITY],
+                community_level=record[schemas.COMMUNITY_LEVEL],
+                community_context=record[schemas.CONTEXT_STRING],
+                cache=cache,
+                callbacks=callbacks,
+                strategy=strategy,
+            )
+            tick()
+            return result
+
+        local_reports = await derive_from_rows(
+            level_contexts,
+            run_generate,
+            callbacks=NoopVerbCallbacks(),
+            num_threads=num_threads,
+            scheduling_type=async_mode,
+        )
+        reports.extend([lr for lr in local_reports if lr is not None])
+
+    return TableContainer(table=pd.DataFrame(reports))
+
+
+async def _generate_report(
+    runner: CommunityReportsStrategy,
+    cache: PipelineCache,
+    callbacks: VerbCallbacks,
+    strategy: dict,
+    community_id: int | str,
+    community_level: int,
+    community_context: str,
+) -> CommunityReport | None:
+    """Generate a report for a single community."""
+    return await runner(
+        community_id, community_context, community_level, callbacks, cache, strategy
+    )
+
+
+def load_strategy(
+    strategy: CreateCommunityReportsStrategyType,
+) -> CommunityReportsStrategy:
+    """Load strategy method definition."""
+    match strategy:
+        case CreateCommunityReportsStrategyType.graph_intelligence:
+            from .strategies.graph_intelligence import run
+
+            return run
+        case _:
+            msg = f"Unknown strategy: {strategy}"
+            raise ValueError(msg)
diff --git a/func-app/graphrag/index/verbs/graph/report/prepare_community_reports.py b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports.py
new file mode 100644
index 0000000000..3c9ebd451a
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing prepare_community_reports and _prepare_reports_at_level methods definition."""
+
+import logging
+from typing import cast
+
+import pandas as pd
+from datashaper import (
+    TableContainer,
+    VerbCallbacks,
+    VerbInput,
+    progress_iterable,
+    verb,
+)
+
+import graphrag.index.graph.extractors.community_reports.schemas as schemas
+from graphrag.index.graph.extractors.community_reports import (
+    filter_claims_to_nodes,
+    filter_edges_to_nodes,
+    filter_nodes_to_level,
+    get_levels,
+    set_context_exceeds_flag,
+    set_context_size,
+    sort_context,
+)
+from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table
+
+log = logging.getLogger(__name__)
+
+
+@verb(name="prepare_community_reports")
+def prepare_community_reports(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    max_tokens: int = 16_000,
+    **_kwargs,
+) -> TableContainer:
+    """Build the local context for each community, to be summarized into a community report."""
+    # Prepare Community Reports
+    node_df = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)
+    edge_df = cast(pd.DataFrame, get_required_input_table(input, "edges").table)
+    claim_df = get_named_input_table(input, "claims")
+    if claim_df is not None:
+        claim_df = cast(pd.DataFrame, claim_df.table)
+
+    levels = get_levels(node_df, schemas.NODE_LEVEL)
+    dfs = []
+
+    for level in progress_iterable(levels, callbacks.progress, len(levels)):
+        communities_at_level_df = _prepare_reports_at_level(
+            node_df, edge_df, claim_df, level, max_tokens
+        )
+        dfs.append(communities_at_level_df)
+
+    # build initial local context for all communities
+    return TableContainer(table=pd.concat(dfs))
+
+
+def _prepare_reports_at_level(
+    node_df: pd.DataFrame,
+    edge_df: pd.DataFrame,
+    claim_df: pd.DataFrame | None,
+    level: int,
+    max_tokens: int = 16_000,
+    community_id_column: str = schemas.COMMUNITY_ID,
+    node_id_column: str = schemas.NODE_ID,
+    node_name_column: str = schemas.NODE_NAME,
+    node_details_column: str = schemas.NODE_DETAILS,
+    node_level_column: str = schemas.NODE_LEVEL,
+    node_degree_column: str = schemas.NODE_DEGREE,
+    node_community_column: str = schemas.NODE_COMMUNITY,
+    edge_id_column: str = schemas.EDGE_ID,
+    edge_source_column: str = schemas.EDGE_SOURCE,
+    edge_target_column: str = schemas.EDGE_TARGET,
+    edge_degree_column: str = schemas.EDGE_DEGREE,
+    edge_details_column: str = schemas.EDGE_DETAILS,
+    claim_id_column: str = schemas.CLAIM_ID,
+    claim_subject_column: str = schemas.CLAIM_SUBJECT,
+    claim_details_column: str = schemas.CLAIM_DETAILS,
+):
+    def get_edge_details(node_df: pd.DataFrame, edge_df: pd.DataFrame, name_col: str):
+        return node_df.merge(
+            cast(
+                pd.DataFrame,
+                edge_df[[name_col, schemas.EDGE_DETAILS]],
+            ).rename(columns={name_col: schemas.NODE_NAME}),
+            on=schemas.NODE_NAME,
+            how="left",
+        )
+
+    level_node_df = filter_nodes_to_level(node_df, level)
+    log.info("Number of nodes at level=%s => %s", level, len(level_node_df))
+    nodes = level_node_df[node_name_column].tolist()
+
+    # Filter edges & claims to those containing the target nodes
edges & claims to those containing the target nodes + level_edge_df = filter_edges_to_nodes(edge_df, nodes) + level_claim_df = ( + filter_claims_to_nodes(claim_df, nodes) if claim_df is not None else None + ) + + # concat all edge details per node + merged_node_df = pd.concat( + [ + get_edge_details(level_node_df, level_edge_df, edge_source_column), + get_edge_details(level_node_df, level_edge_df, edge_target_column), + ], + axis=0, + ) + merged_node_df = ( + merged_node_df.groupby([ + node_name_column, + node_community_column, + node_degree_column, + node_level_column, + ]) + .agg({node_details_column: "first", edge_details_column: list}) + .reset_index() + ) + + # concat claim details per node + if level_claim_df is not None: + merged_node_df = merged_node_df.merge( + cast( + pd.DataFrame, + level_claim_df[[claim_subject_column, claim_details_column]], + ).rename(columns={claim_subject_column: node_name_column}), + on=node_name_column, + how="left", + ) + merged_node_df = ( + merged_node_df.groupby([ + node_name_column, + node_community_column, + node_level_column, + node_degree_column, + ]) + .agg({ + node_details_column: "first", + edge_details_column: "first", + **({claim_details_column: list} if level_claim_df is not None else {}), + }) + .reset_index() + ) + + # concat all node details, including name, degree, node_details, edge_details, and claim_details + merged_node_df[schemas.ALL_CONTEXT] = merged_node_df.apply( + lambda x: { + node_name_column: x[node_name_column], + node_degree_column: x[node_degree_column], + node_details_column: x[node_details_column], + edge_details_column: x[edge_details_column], + claim_details_column: x[claim_details_column] + if level_claim_df is not None + else [], + }, + axis=1, + ) + + # group all node details by community + community_df = ( + merged_node_df.groupby(node_community_column) + .agg({schemas.ALL_CONTEXT: list}) + .reset_index() + ) + community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply( + lambda x: sort_context( + x, + node_id_column=node_id_column, + node_name_column=node_name_column, + node_details_column=node_details_column, + edge_id_column=edge_id_column, + edge_details_column=edge_details_column, + edge_degree_column=edge_degree_column, + edge_source_column=edge_source_column, + edge_target_column=edge_target_column, + claim_id_column=claim_id_column, + claim_details_column=claim_details_column, + community_id_column=community_id_column, + ) + ) + set_context_size(community_df) + set_context_exceeds_flag(community_df, max_tokens) + + community_df[schemas.COMMUNITY_LEVEL] = level + return community_df diff --git a/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_claims.py b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_claims.py new file mode 100644 index 0000000000..aa9a790772 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_claims.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A module containing prepare_community_reports_claims method definition."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+from graphrag.index.graph.extractors.community_reports.schemas import (
+    CLAIM_DESCRIPTION,
+    CLAIM_DETAILS,
+    CLAIM_ID,
+    CLAIM_STATUS,
+    CLAIM_SUBJECT,
+    CLAIM_TYPE,
+)
+
+_MISSING_DESCRIPTION = "No Description"
+
+
+@verb(name="prepare_community_reports_claims")
+def prepare_community_reports_claims(
+    input: VerbInput,
+    to: str = CLAIM_DETAILS,
+    id_column: str = CLAIM_ID,
+    description_column: str = CLAIM_DESCRIPTION,
+    subject_column: str = CLAIM_SUBJECT,
+    type_column: str = CLAIM_TYPE,
+    status_column: str = CLAIM_STATUS,
+    **_kwargs,
+) -> TableContainer:
+    """Merge claim details into an object."""
+    claim_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
+    claim_df = claim_df.fillna(value={description_column: _MISSING_DESCRIPTION})
+
+    # merge values of five columns into a map column
+    claim_df[to] = claim_df.apply(
+        lambda x: {
+            id_column: x[id_column],
+            subject_column: x[subject_column],
+            type_column: x[type_column],
+            status_column: x[status_column],
+            description_column: x[description_column],
+        },
+        axis=1,
+    )
+
+    return TableContainer(table=claim_df)
diff --git a/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_edges.py b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_edges.py
new file mode 100644
index 0000000000..b568aba006
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_edges.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing prepare_community_reports_edges method definition."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+from graphrag.index.graph.extractors.community_reports.schemas import (
+    EDGE_DEGREE,
+    EDGE_DESCRIPTION,
+    EDGE_DETAILS,
+    EDGE_ID,
+    EDGE_SOURCE,
+    EDGE_TARGET,
+)
+
+_MISSING_DESCRIPTION = "No Description"
+
+
+@verb(name="prepare_community_reports_edges")
+def prepare_community_reports_edges(
+    input: VerbInput,
+    to: str = EDGE_DETAILS,
+    id_column: str = EDGE_ID,
+    source_column: str = EDGE_SOURCE,
+    target_column: str = EDGE_TARGET,
+    description_column: str = EDGE_DESCRIPTION,
+    degree_column: str = EDGE_DEGREE,
+    **_kwargs,
+) -> TableContainer:
+    """Merge edge details into an object."""
+    edge_df: pd.DataFrame = cast(pd.DataFrame, input.get_input()).fillna(
+        value={description_column: _MISSING_DESCRIPTION}
+    )
+    edge_df[to] = edge_df.apply(
+        lambda x: {
+            id_column: x[id_column],
+            source_column: x[source_column],
+            target_column: x[target_column],
+            description_column: x[description_column],
+            degree_column: x[degree_column],
+        },
+        axis=1,
+    )
+    return TableContainer(table=edge_df)
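+
+# Output sketch (illustrative): each edge row gains an EDGE_DETAILS dict such as
+#   {id: ..., source: ..., target: ..., description: ..., degree: ...}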
diff --git a/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_nodes.py b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_nodes.py
new file mode 100644
index 0000000000..f159c125ee
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/prepare_community_reports_nodes.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing prepare_community_reports_nodes method definition."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+from graphrag.index.graph.extractors.community_reports.schemas import (
+    NODE_DEGREE,
+    NODE_DESCRIPTION,
+    NODE_DETAILS,
+    NODE_ID,
+    NODE_NAME,
+)
+
+_MISSING_DESCRIPTION = "No Description"
+
+
+@verb(name="prepare_community_reports_nodes")
+def prepare_community_reports_nodes(
+    input: VerbInput,
+    to: str = NODE_DETAILS,
+    id_column: str = NODE_ID,
+    name_column: str = NODE_NAME,
+    description_column: str = NODE_DESCRIPTION,
+    degree_column: str = NODE_DEGREE,
+    **_kwargs,
+) -> TableContainer:
+    """Merge node details into an object."""
+    node_df = cast(pd.DataFrame, input.get_input())
+    node_df = node_df.fillna(value={description_column: _MISSING_DESCRIPTION})
+
+    # merge values of four columns into a map column
+    node_df[to] = node_df.apply(
+        lambda x: {
+            id_column: x[id_column],
+            name_column: x[name_column],
+            description_column: x[description_column],
+            degree_column: x[degree_column],
+        },
+        axis=1,
+    )
+    return TableContainer(table=node_df)
diff --git a/func-app/graphrag/index/verbs/graph/report/restore_community_hierarchy.py b/func-app/graphrag/index/verbs/graph/report/restore_community_hierarchy.py
new file mode 100644
index 0000000000..437369f0e5
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/restore_community_hierarchy.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing restore_community_hierarchy method definition."""
+
+import logging
+from typing import cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+import graphrag.index.graph.extractors.community_reports.schemas as schemas
+
+log = logging.getLogger(__name__)
+
+
+@verb(name="restore_community_hierarchy")
+def restore_community_hierarchy(
+    input: VerbInput,
+    name_column: str = schemas.NODE_NAME,
+    community_column: str = schemas.NODE_COMMUNITY,
+    level_column: str = schemas.NODE_LEVEL,
+    **_kwargs,
+) -> TableContainer:
+    """Restore the community hierarchy from the node data."""
+    node_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
+    community_df = (
+        node_df.groupby([community_column, level_column])
+        .agg({name_column: list})
+        .reset_index()
+    )
+    community_levels = {}
+    for _, row in community_df.iterrows():
+        level = row[level_column]
+        name = row[name_column]
+        community = row[community_column]
+
+        if community_levels.get(level) is None:
+            community_levels[level] = {}
+        community_levels[level][community] = name
+
+    # get unique levels, sorted in ascending order
+    levels = sorted(community_levels.keys())
+
+    community_hierarchy = []
+
+    for idx in range(len(levels) - 1):
+        level = levels[idx]
+        log.debug("Level: %s", level)
+        next_level = levels[idx + 1]
+        current_level_communities = community_levels[level]
+        next_level_communities = community_levels[next_level]
+        log.debug(
+            "Number of communities at level %s: %s",
+            level,
+            len(current_level_communities),
+        )
+
+        for current_community in current_level_communities:
+            current_entities = current_level_communities[current_community]
+
+            # loop through next level's communities to find all the subcommunities
+            entities_found = 0
+            for next_level_community in next_level_communities:
+                
next_entities = next_level_communities[next_level_community] + if set(next_entities).issubset(set(current_entities)): + community_hierarchy.append({ + community_column: current_community, + schemas.COMMUNITY_LEVEL: level, + schemas.SUB_COMMUNITY: next_level_community, + schemas.SUB_COMMUNITY_SIZE: len(next_entities), + }) + + entities_found += len(next_entities) + if entities_found == len(current_entities): + break + + return TableContainer(table=pd.DataFrame(community_hierarchy)) diff --git a/func-app/graphrag/index/verbs/graph/report/strategies/__init__.py b/func-app/graphrag/index/verbs/graph/report/strategies/__init__.py new file mode 100644 index 0000000000..87d1f9e252 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/strategies/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph report strategies package root.""" diff --git a/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/__init__.py b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/__init__.py new file mode 100644 index 0000000000..7f51d7909b --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine graph report strategies graph intelligence package root.""" + +from .run_graph_intelligence import run + +__all__ = ["run"] diff --git a/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/defaults.py b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/defaults.py new file mode 100644 index 0000000000..708d48d2b6 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/defaults.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024 Microsoft Corporation. 
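+#
+# Usage sketch (illustrative): MOCK_RESPONSES pairs with LLMType.StaticResponse so the
+# reporting strategy can be exercised without a live model, e.g.
+#   llm_config = {"type": LLMType.StaticResponse, "responses": MOCK_RESPONSES}
+#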
+# Licensed under the MIT License
+
+"""A file containing DEFAULT_CHUNK_SIZE and MOCK_RESPONSES definitions."""
+
+import json
+
+DEFAULT_CHUNK_SIZE = 3000
+MOCK_RESPONSES = [
+    json.dumps({
+        "title": "<report_title>",
+        "summary": "<executive_summary>",
+        "rating": 2,
+        "rating_explanation": "<rating_explanation>",
+        "findings": [
+            {
+                "summary": "<insight_1_summary>",
+                "explanation": "<insight_1_explanation>",
+            },
+            {
+                "summary": "<insight_2_summary>",
+                "explanation": "<insight_2_explanation>",
+            },
+        ],
+    })
]
diff --git a/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/run_graph_intelligence.py b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/run_graph_intelligence.py
new file mode 100644
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/strategies/graph_intelligence/run_graph_intelligence.py
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run, _run_extractor and _parse_rank methods definition."""
+
+import json
+import logging
+import traceback
+
+from datashaper import VerbCallbacks
+
+from graphrag.config.enums import LLMType
+from graphrag.index.cache import PipelineCache
+from graphrag.index.graph.extractors.community_reports import (
+    CommunityReportsExtractor,
+)
+from graphrag.index.llm import load_llm
+from graphrag.index.utils.rate_limiter import RateLimiter
+from graphrag.index.verbs.graph.report.strategies.typing import (
+    CommunityReport,
+    StrategyConfig,
+)
+from graphrag.llm import CompletionLLM
+
+from .defaults import MOCK_RESPONSES
+
+log = logging.getLogger(__name__)
+
+
+async def run(
+    community: str | int,
+    input: str,
+    level: int,
+    reporter: VerbCallbacks,
+    pipeline_cache: PipelineCache,
+    args: StrategyConfig,
+) -> CommunityReport | None:
+    """Run the graph intelligence entity extraction strategy."""
+    llm_config = args.get(
+        "llm", {"type": LLMType.StaticResponse, "responses": MOCK_RESPONSES}
+    )
+    llm_type = llm_config.get("type", LLMType.StaticResponse)
+    llm = load_llm(
+        "community_reporting", llm_type, reporter, pipeline_cache, llm_config
+    )
+    return await _run_extractor(llm, community, input, level, args, reporter)
+
+
+async def _run_extractor(
+    llm: CompletionLLM,
+    community: str | int,
+    input: str,
+    level: int,
+    args: StrategyConfig,
+    reporter: VerbCallbacks,
+) -> CommunityReport | None:
+    # RateLimiter - throttle report generation to one request per minute
+    rate_limiter = RateLimiter(rate=1, per=60)
+    extractor = CommunityReportsExtractor(
+        llm,
+        extraction_prompt=args.get("extraction_prompt", None),
+        max_report_length=args.get("max_report_length", None),
+        on_error=lambda e, stack, _data: reporter.error(
+            "Community Report Extraction Error", e, stack
+        ),
+    )
+
+    try:
+        await rate_limiter.acquire()
+        results = await extractor({"input_text": input})
+        report = results.structured_output
+        if report is None or len(report.keys()) == 0:
+            log.warning("No report found for community: %s", community)
+            return None
+
+        return CommunityReport(
+            community=community,
+            full_content=results.output,
+            level=level,
+            rank=_parse_rank(report),
+            title=report.get("title", f"Community Report: {community}"),
+            rank_explanation=report.get("rating_explanation", ""),
+            summary=report.get("summary", ""),
+            findings=report.get("findings", []),
+            full_content_json=json.dumps(report, indent=4),
+        )
+    except Exception as e:
+        log.exception("Error processing community: %s", community)
+        reporter.error("Community Report Extraction Error", e, traceback.format_exc())
+        return None
+
+
+def _parse_rank(report: dict) -> float:
+    rank = report.get("rating", -1)
+    try:
+        return float(rank)
+    except ValueError:
+        log.exception("Error parsing rank: %s defaulting to -1", rank)
+        return -1
diff --git a/func-app/graphrag/index/verbs/graph/report/strategies/typing.py b/func-app/graphrag/index/verbs/graph/report/strategies/typing.py
new file mode 100644
index 0000000000..087c724702
--- /dev/null
+++ b/func-app/graphrag/index/verbs/graph/report/strategies/typing.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 Microsoft Corporation.
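+#
+# Shape sketch (illustrative): a CommunityReportsStrategy is awaited as
+#   report = await strategy(community_id, context_text, level, callbacks, cache, config)
+# and yields a CommunityReport dict (title, summary, rank, findings, ...) or None.
+#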
+# Licensed under the MIT License + +"""A module containing 'Finding' and 'CommunityReport' models.""" + +from collections.abc import Awaitable, Callable +from typing import Any + +from datashaper import VerbCallbacks +from typing_extensions import TypedDict + +from graphrag.index.cache import PipelineCache + +ExtractedEntity = dict[str, Any] +StrategyConfig = dict[str, Any] +RowContext = dict[str, Any] +EntityTypes = list[str] +Claim = dict[str, Any] + + +class Finding(TypedDict): + """Finding class definition.""" + + summary: str + explanation: str + + +class CommunityReport(TypedDict): + """Community report class definition.""" + + community: str | int + title: str + summary: str + full_content: str + full_content_json: str + rank: float + level: int + rank_explanation: str + findings: list[Finding] + + +CommunityReportsStrategy = Callable[ + [ + str | int, + str, + int, + VerbCallbacks, + PipelineCache, + StrategyConfig, + ], + Awaitable[CommunityReport | None], +] diff --git a/func-app/graphrag/index/verbs/graph/unpack.py b/func-app/graphrag/index/verbs/graph/unpack.py new file mode 100644 index 0000000000..ffb7f4b0a2 --- /dev/null +++ b/func-app/graphrag/index/verbs/graph/unpack.py @@ -0,0 +1,107 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing unpack_graph, _run_unpack, _unpack_nodes and _unpack_edges methods definition.""" + +from typing import Any, cast + +import networkx as nx +import pandas as pd +from datashaper import TableContainer, VerbCallbacks, VerbInput, progress_iterable, verb + +from graphrag.index.utils import load_graph + +default_copy = ["level"] + + +@verb(name="unpack_graph") +def unpack_graph( + input: VerbInput, + callbacks: VerbCallbacks, + column: str, + type: str, # noqa A002 + copy: list[str] | None = None, + embeddings_column: str = "embeddings", + **kwargs, +) -> TableContainer: + """ + Unpack nodes or edges from a graphml graph, into a list of nodes or edges. + + This verb will create columns for each attribute in a node or edge. + + ## Usage + ```yaml + verb: unpack_graph + args: + type: node # The type of data to unpack, one of: node, edge. 
node will create a node list, edge will create an edge list + column: # The name of the column containing the graph, should be a graphml graph + ``` + """ + if copy is None: + copy = default_copy + input_df = input.get_input() + num_total = len(input_df) + result = [] + copy = [col for col in copy if col in input_df.columns] + has_embeddings = embeddings_column in input_df.columns + + for _, row in progress_iterable(input_df.iterrows(), callbacks.progress, num_total): + # merge the original row with the unpacked graph item + cleaned_row = {col: row[col] for col in copy} + embeddings = ( + cast(dict[str, list[float]], row[embeddings_column]) + if has_embeddings + else {} + ) + + result.extend([ + {**cleaned_row, **graph_id} + for graph_id in _run_unpack( + cast(str | nx.Graph, row[column]), + type, + embeddings, + kwargs, + ) + ]) + + output_df = pd.DataFrame(result) + return TableContainer(table=output_df) + + +def _run_unpack( + graphml_or_graph: str | nx.Graph, + unpack_type: str, + embeddings: dict[str, list[float]], + args: dict[str, Any], +) -> list[dict[str, Any]]: + graph = load_graph(graphml_or_graph) + if unpack_type == "nodes": + return _unpack_nodes(graph, embeddings, args) + if unpack_type == "edges": + return _unpack_edges(graph, args) + msg = f"Unknown type {unpack_type}" + raise ValueError(msg) + + +def _unpack_nodes( + graph: nx.Graph, embeddings: dict[str, list[float]], _args: dict[str, Any] +) -> list[dict[str, Any]]: + return [ + { + "label": label, + **(node_data or {}), + "graph_embedding": embeddings.get(label), + } + for label, node_data in graph.nodes(data=True) # type: ignore + ] + + +def _unpack_edges(graph: nx.Graph, _args: dict[str, Any]) -> list[dict[str, Any]]: + return [ + { + "source": source_id, + "target": target_id, + **(edge_data or {}), + } + for source_id, target_id, edge_data in graph.edges(data=True) # type: ignore + ] diff --git a/func-app/graphrag/index/verbs/overrides/__init__.py b/func-app/graphrag/index/verbs/overrides/__init__.py new file mode 100644 index 0000000000..24b82c1f3e --- /dev/null +++ b/func-app/graphrag/index/verbs/overrides/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine overrides package root.""" + +from .aggregate import aggregate +from .concat import concat +from .merge import merge + +__all__ = ["aggregate", "concat", "merge"] diff --git a/func-app/graphrag/index/verbs/overrides/aggregate.py b/func-app/graphrag/index/verbs/overrides/aggregate.py new file mode 100644 index 0000000000..df2137046b --- /dev/null +++ b/func-app/graphrag/index/verbs/overrides/aggregate.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'Aggregation' model.""" + +# Copyright (c) 2024 Microsoft Corporation. 
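+#
+# Config sketch (illustrative names): each aggregation entry maps a source column
+# through an operation into a target column, e.g.
+#   aggregations: [{"column": "amount", "operation": "sum", "to": "total_amount"}]
+# The "string_concat" operation joins values with a separator (default ",").
+#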
+# Licensed under the MIT License +from dataclasses import dataclass +from typing import Any, cast + +import pandas as pd +from datashaper import ( + FieldAggregateOperation, + Progress, + TableContainer, + VerbCallbacks, + VerbInput, + aggregate_operation_mapping, + verb, +) + +ARRAY_AGGREGATIONS = [ + FieldAggregateOperation.ArrayAgg, + FieldAggregateOperation.ArrayAggDistinct, +] + + +# TODO: This thing is kinda gross +# Also, it diverges from the original aggregate verb, since it doesn't support the same syntax +@verb(name="aggregate_override") +def aggregate( + input: VerbInput, + callbacks: VerbCallbacks, + aggregations: list[dict[str, Any]], + groupby: list[str] | None = None, + **_kwargs: dict, +) -> TableContainer: + """Aggregate method definition.""" + aggregations_to_apply = _load_aggregations(aggregations) + df_aggregations = { + agg.column: _get_pandas_agg_operation(agg) + for agg in aggregations_to_apply.values() + } + input_table = input.get_input() + callbacks.progress(Progress(percent=0)) + + if groupby is None: + output_grouped = input_table.groupby(lambda _x: True) + else: + output_grouped = input_table.groupby(groupby, sort=False) + output = cast(pd.DataFrame, output_grouped.agg(df_aggregations)) + output.rename( + columns={agg.column: agg.to for agg in aggregations_to_apply.values()}, + inplace=True, + ) + output.columns = [agg.to for agg in aggregations_to_apply.values()] + + callbacks.progress(Progress(percent=1)) + + return TableContainer(table=output.reset_index()) + + +@dataclass +class Aggregation: + """Aggregation class method definition.""" + + column: str | None + operation: str + to: str + + # Only useful for the concat operation + separator: str | None = None + + +def _get_pandas_agg_operation(agg: Aggregation) -> Any: + # TODO: Merge into datashaper + if agg.operation == "string_concat": + return (agg.separator or ",").join + return aggregate_operation_mapping[FieldAggregateOperation(agg.operation)] + + +def _load_aggregations( + aggregations: list[dict[str, Any]], +) -> dict[str, Aggregation]: + return { + aggregation["column"]: Aggregation( + aggregation["column"], aggregation["operation"], aggregation["to"] + ) + for aggregation in aggregations + } diff --git a/func-app/graphrag/index/verbs/overrides/concat.py b/func-app/graphrag/index/verbs/overrides/concat.py new file mode 100644 index 0000000000..7a0f0e2c32 --- /dev/null +++ b/func-app/graphrag/index/verbs/overrides/concat.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing concat method definition.""" + +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License +from typing import cast + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + + +@verb(name="concat_override") +def concat( + input: VerbInput, + columnwise: bool = False, + **_kwargs: dict, +) -> TableContainer: + """Concat method definition.""" + input_table = cast(pd.DataFrame, input.get_input()) + others = cast(list[pd.DataFrame], input.get_others()) + if columnwise: + output = pd.concat([input_table, *others], axis=1) + else: + output = pd.concat([input_table, *others], ignore_index=True) + return TableContainer(table=output) diff --git a/func-app/graphrag/index/verbs/overrides/merge.py b/func-app/graphrag/index/verbs/overrides/merge.py new file mode 100644 index 0000000000..64684c9828 --- /dev/null +++ b/func-app/graphrag/index/verbs/overrides/merge.py @@ -0,0 +1,78 @@ +# Copyright (c) 2024 Microsoft Corporation. 
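+#
+# Behavior sketch (illustrative): with strategy "json", the listed columns are packed
+# into a single dict column per row; any other strategy falls through to datashaper's
+# built-in merge verb.
+#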
+# Licensed under the MIT License
+
+"""A module containing merge and _merge_json methods definition."""
+
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+import logging
+from enum import Enum
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, VerbResult, verb
+from datashaper.engine.verbs.merge import merge as ds_merge
+
+log = logging.getLogger(__name__)
+
+
+class MergeStrategyType(str, Enum):
+    """MergeStrategy class definition."""
+
+    json = "json"
+    datashaper = "datashaper"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+# TODO: This thing is kinda gross
+# Also, it diverges from the original aggregate verb, since it doesn't support the same syntax
+@verb(name="merge_override")
+def merge(
+    input: VerbInput,
+    to: str,
+    columns: list[str],
+    strategy: MergeStrategyType = MergeStrategyType.datashaper,
+    delimiter: str = "",
+    preserveSource: bool = False,  # noqa N806
+    unhot: bool = False,
+    prefix: str = "",
+    **_kwargs: dict,
+) -> TableContainer | VerbResult:
+    """Merge method definition."""
+    output: pd.DataFrame
+    match strategy:
+        case MergeStrategyType.json:
+            output = _merge_json(input, to, columns)
+            filtered_list: list[str] = []
+
+            for col in output.columns:
+                try:
+                    columns.index(col)
+                except ValueError:
+                    log.exception("Column %s not found in input columns", col)
+                    filtered_list.append(col)
+
+            if not preserveSource:
+                output = cast(Any, output[filtered_list])
+            return TableContainer(table=output.reset_index())
+        case _:
+            return ds_merge(
+                input, to, columns, strategy, delimiter, preserveSource, unhot, prefix
+            )
+
+
+def _merge_json(
+    input: VerbInput,
+    to: str,
+    columns: list[str],
+) -> pd.DataFrame:
+    input_table = cast(pd.DataFrame, input.get_input())
+    output = input_table
+    output[to] = output[columns].apply(
+        lambda row: ({**row}),
+        axis=1,
+    )
+    return output
diff --git a/func-app/graphrag/index/verbs/snapshot.py b/func-app/graphrag/index/verbs/snapshot.py
new file mode 100644
index 0000000000..b781478532
--- /dev/null
+++ b/func-app/graphrag/index/verbs/snapshot.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing snapshot method definition."""
+
+from datashaper import TableContainer, VerbInput, verb
+
+from graphrag.common.storage import PipelineStorage
+
+
+@verb(name="snapshot")
+async def snapshot(
+    input: VerbInput,
+    name: str,
+    formats: list[str],
+    storage: PipelineStorage,
+    **_kwargs: dict,
+) -> TableContainer:
+    """Take an entire snapshot of the tabular data."""
+    data = input.get_input()
+
+    for fmt in formats:
+        if fmt == "parquet":
+            await storage.set(name + ".parquet", data.to_parquet())
+        elif fmt == "json":
+            await storage.set(
+                name + ".json", data.to_json(orient="records", lines=True)
+            )
+
+    return TableContainer(table=data)
diff --git a/func-app/graphrag/index/verbs/snapshot_rows.py b/func-app/graphrag/index/verbs/snapshot_rows.py
new file mode 100644
index 0000000000..0b0ca1c3b6
--- /dev/null
+++ b/func-app/graphrag/index/verbs/snapshot_rows.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2024 Microsoft Corporation.
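+#
+# Config sketch (illustrative): formats may be plain names or explicit specifiers, e.g.
+#   formats: ["json", {"format": "text", "extension": "txt"}]
+# Plain names get default extensions: json -> json, text -> txt, parquet -> parquet.
+#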
+# Licensed under the MIT License + +"""A module containing 'FormatSpecifier' model.""" + +import json +from dataclasses import dataclass +from typing import Any + +from datashaper import TableContainer, VerbInput, verb + +from graphrag.common.storage import PipelineStorage + + +@dataclass +class FormatSpecifier: + """Format specifier class definition.""" + + format: str + extension: str + + +@verb(name="snapshot_rows") +async def snapshot_rows( + input: VerbInput, + column: str | None, + base_name: str, + storage: PipelineStorage, + formats: list[str | dict[str, Any]], + row_name_column: str | None = None, + **_kwargs: dict, +) -> TableContainer: + """Take a by-row snapshot of the tabular data.""" + data = input.get_input() + parsed_formats = _parse_formats(formats) + num_rows = len(data) + + def get_row_name(row: Any, row_idx: Any): + if row_name_column is None: + if num_rows == 1: + return base_name + return f"{base_name}.{row_idx}" + return f"{base_name}.{row[row_name_column]}" + + for row_idx, row in data.iterrows(): + for fmt in parsed_formats: + row_name = get_row_name(row, row_idx) + extension = fmt.extension + if fmt.format == "json": + await storage.set( + f"{row_name}.{extension}", + json.dumps(row[column]) + if column is not None + else json.dumps(row.to_dict()), + ) + elif fmt.format == "text": + if column is None: + msg = "column must be specified for text format" + raise ValueError(msg) + await storage.set(f"{row_name}.{extension}", str(row[column])) + + return TableContainer(table=data) + + +def _parse_formats(formats: list[str | dict[str, Any]]) -> list[FormatSpecifier]: + """Parse the formats into a list of FormatSpecifiers.""" + return [ + FormatSpecifier(**fmt) + if isinstance(fmt, dict) + else FormatSpecifier(format=fmt, extension=_get_format_extension(fmt)) + for fmt in formats + ] + + +def _get_format_extension(fmt: str) -> str: + """Get the file extension for a given format.""" + if fmt == "json": + return "json" + if fmt == "text": + return "txt" + if fmt == "parquet": + return "parquet" + if fmt == "csv": + return "csv" + msg = f"Unknown format: {fmt}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/spread_json.py b/func-app/graphrag/index/verbs/spread_json.py new file mode 100644 index 0000000000..38656e12a4 --- /dev/null +++ b/func-app/graphrag/index/verbs/spread_json.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing spread_json method definition.""" + +import logging + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + +from graphrag.index.utils import is_null + +# TODO: Check if this is already a thing +DEFAULT_COPY = ["level"] + + +@verb(name="spread_json") +def spread_json( + input: VerbInput, + column: str, + copy: list[str] | None = None, + **_kwargs: dict, +) -> TableContainer: + """ + Unpack a column containing a tuple into multiple columns. 
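+    Each key of the JSON object in that column becomes its own column, for example: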
+ + id|json|b + 1|{"x":5,"y":6}|b + + is converted to + + id|x|y|b + -------- + 1|5|6|b + """ + if copy is None: + copy = DEFAULT_COPY + data = input.get_input() + + results = [] + for _, row in data.iterrows(): + try: + cleaned_row = {col: row[col] for col in copy} + rest_row = row[column] if row[column] is not None else {} + + if is_null(rest_row): + rest_row = {} + + results.append({**cleaned_row, **rest_row}) # type: ignore + except Exception: + logging.exception("Error spreading row: %s", row) + raise + data = pd.DataFrame(results, index=data.index) + + return TableContainer(table=data) diff --git a/func-app/graphrag/index/verbs/text/__init__.py b/func-app/graphrag/index/verbs/text/__init__.py new file mode 100644 index 0000000000..032f45e1b1 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine text package root.""" + +from .chunk.text_chunk import chunk +from .embed import text_embed +from .replace import replace +from .split import text_split +from .translate import text_translate + +__all__ = [ + "chunk", + "replace", + "text_embed", + "text_split", + "text_translate", +] diff --git a/func-app/graphrag/index/verbs/text/chunk/__init__.py b/func-app/graphrag/index/verbs/text/chunk/__init__.py new file mode 100644 index 0000000000..4e2a7729c5 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/chunk/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine text chunk package root.""" + +from .text_chunk import ChunkStrategy, ChunkStrategyType, chunk + +__all__ = ["ChunkStrategy", "ChunkStrategyType", "chunk"] diff --git a/func-app/graphrag/index/verbs/text/chunk/strategies/__init__.py b/func-app/graphrag/index/verbs/text/chunk/strategies/__init__.py new file mode 100644 index 0000000000..0f15fcb2d5 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/chunk/strategies/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine text chunk strategies package root.""" diff --git a/func-app/graphrag/index/verbs/text/chunk/strategies/sentence.py b/func-app/graphrag/index/verbs/text/chunk/strategies/sentence.py new file mode 100644 index 0000000000..687def1d90 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/chunk/strategies/sentence.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run method definition.""" + +from collections.abc import Iterable +from typing import Any + +import nltk +from datashaper import ProgressTicker + +from .typing import TextChunk + + +def run( + input: list[str], _args: dict[str, Any], tick: ProgressTicker +) -> Iterable[TextChunk]: + """Chunks text into multiple parts. A pipeline verb.""" + for doc_idx, text in enumerate(input): + sentences = nltk.sent_tokenize(text) + for sentence in sentences: + yield TextChunk( + text_chunk=sentence, + source_doc_indices=[doc_idx], + ) + tick(1) diff --git a/func-app/graphrag/index/verbs/text/chunk/strategies/tokens.py b/func-app/graphrag/index/verbs/text/chunk/strategies/tokens.py new file mode 100644 index 0000000000..6426c783e1 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/chunk/strategies/tokens.py @@ -0,0 +1,81 @@ +# Copyright (c) 2024 Microsoft Corporation. 
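+#
+# Worked example (illustrative): with tokens_per_chunk=1200 and chunk_overlap=100, the
+# window over the concatenated token stream advances 1100 tokens per step, so adjacent
+# chunks share 100 tokens of context.
+#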
+# Licensed under the MIT License + +"""A module containing run and split_text_on_tokens methods definition.""" + +from collections.abc import Iterable +from typing import Any + +import tiktoken +from datashaper import ProgressTicker + +import graphrag.config.defaults as defs +from graphrag.index.text_splitting import Tokenizer +from graphrag.index.verbs.text.chunk.typing import TextChunk + + +def run( + input: list[str], args: dict[str, Any], tick: ProgressTicker +) -> Iterable[TextChunk]: + """Chunks text into multiple parts. A pipeline verb.""" + tokens_per_chunk = args.get("chunk_size", defs.CHUNK_SIZE) + chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP) + encoding_name = args.get("encoding_name", defs.ENCODING_MODEL) + enc = tiktoken.get_encoding(encoding_name) + + def encode(text: str) -> list[int]: + if not isinstance(text, str): + text = f"{text}" + return enc.encode(text) + + def decode(tokens: list[int]) -> str: + return enc.decode(tokens) + + return split_text_on_tokens( + input, + Tokenizer( + chunk_overlap=chunk_overlap, + tokens_per_chunk=tokens_per_chunk, + encode=encode, + decode=decode, + ), + tick, + ) + + +# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471 +# So we could have better control over the chunking process +def split_text_on_tokens( + texts: list[str], enc: Tokenizer, tick: ProgressTicker +) -> list[TextChunk]: + """Split incoming text and return chunks.""" + result = [] + mapped_ids = [] + + for source_doc_idx, text in enumerate(texts): + encoded = enc.encode(text) + tick(1) + mapped_ids.append((source_doc_idx, encoded)) + + input_ids: list[tuple[int, int]] = [ + (source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids + ] + + start_idx = 0 + cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + while start_idx < len(input_ids): + chunk_text = enc.decode([id for _, id in chunk_ids]) + doc_indices = list({doc_idx for doc_idx, _ in chunk_ids}) + result.append( + TextChunk( + text_chunk=chunk_text, + source_doc_indices=doc_indices, + n_tokens=len(chunk_ids), + ) + ) + start_idx += enc.tokens_per_chunk - enc.chunk_overlap + cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + + return result diff --git a/func-app/graphrag/index/verbs/text/chunk/strategies/typing.py b/func-app/graphrag/index/verbs/text/chunk/strategies/typing.py new file mode 100644 index 0000000000..b4e833c8e3 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/chunk/strategies/typing.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing ChunkStrategy definition.""" + +from collections.abc import Callable, Iterable +from typing import Any + +from datashaper import ProgressTicker + +from graphrag.index.verbs.text.chunk.typing import TextChunk + +# Given a list of document texts, return a list of tuples of (source_doc_indices, text_chunk) + +ChunkStrategy = Callable[ + [list[str], dict[str, Any], ProgressTicker], Iterable[TextChunk] +] diff --git a/func-app/graphrag/index/verbs/text/chunk/text_chunk.py b/func-app/graphrag/index/verbs/text/chunk/text_chunk.py new file mode 100644 index 0000000000..40c5578a0f --- /dev/null +++ b/func-app/graphrag/index/verbs/text/chunk/text_chunk.py @@ -0,0 +1,162 @@ +# Copyright (c) 2024 Microsoft Corporation. 
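+#
+# I/O sketch (illustrative): the source column may hold plain text or (doc_id, text)
+# tuples; plain text yields chunk strings, while tuples yield
+# (doc_ids, chunk_text, n_tokens) triples so chunks stay traceable to documents.
+#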
+# Licensed under the MIT License + +"""A module containing _get_num_total, chunk, run_strategy and load_strategy methods definitions.""" + +from enum import Enum +from typing import Any, cast + +import pandas as pd +from datashaper import ( + ProgressTicker, + TableContainer, + VerbCallbacks, + VerbInput, + progress_ticker, + verb, +) + +from .strategies.typing import ChunkStrategy as ChunkStrategy +from .typing import ChunkInput + + +def _get_num_total(output: pd.DataFrame, column: str) -> int: + num_total = 0 + for row in output[column]: + if isinstance(row, str): + num_total += 1 + else: + num_total += len(row) + return num_total + + +class ChunkStrategyType(str, Enum): + """ChunkStrategy class definition.""" + + tokens = "tokens" + sentence = "sentence" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +@verb(name="chunk") +def chunk( + input: VerbInput, + column: str, + to: str, + callbacks: VerbCallbacks, + strategy: dict[str, Any] | None = None, + **_kwargs, +) -> TableContainer: + """ + Chunk a piece of text into smaller pieces. + + ## Usage + ```yaml + verb: text_chunk + args: + column: # The name of the column containing the text to chunk, this can either be a column with text, or a column with a list[tuple[doc_id, str]] + to: # The name of the column to output the chunks to + strategy: # The strategy to use to chunk the text, see below for more details + ``` + + ## Strategies + The text chunk verb uses a strategy to chunk the text. The strategy is an object which defines the strategy to use. The following strategies are available: + + ### tokens + This strategy uses the [tokens] library to chunk a piece of text. The strategy config is as follows: + + > Note: In the future, this will likely be renamed to something more generic, like "openai_tokens". + + ```yaml + strategy: + type: tokens + chunk_size: 1200 # Optional, The chunk size to use, default: 1200 + chunk_overlap: 100 # Optional, The chunk overlap to use, default: 100 + ``` + + ### sentence + This strategy uses the nltk library to chunk a piece of text into sentences. 
The strategy config is as follows: + + ```yaml + strategy: + type: sentence + ``` + """ + if strategy is None: + strategy = {} + output = cast(pd.DataFrame, input.get_input()) + strategy_name = strategy.get("type", ChunkStrategyType.tokens) + strategy_config = {**strategy} + strategy_exec = load_strategy(strategy_name) + + num_total = _get_num_total(output, column) + tick = progress_ticker(callbacks.progress, num_total) + + output[to] = output.apply( + cast( + Any, + lambda x: run_strategy(strategy_exec, x[column], strategy_config, tick), + ), + axis=1, + ) + return TableContainer(table=output) + + +def run_strategy( + strategy: ChunkStrategy, + input: ChunkInput, + strategy_args: dict[str, Any], + tick: ProgressTicker, +) -> list[str | tuple[list[str] | None, str, int]]: + """Run strategy method definition.""" + if isinstance(input, str): + return [item.text_chunk for item in strategy([input], {**strategy_args}, tick)] + + # We can work with both just a list of text content + # or a list of tuples of (document_id, text content) + # text_to_chunk = ''' + texts = [] + for item in input: + if isinstance(item, str): + texts.append(item) + else: + texts.append(item[1]) + + strategy_results = strategy(texts, {**strategy_args}, tick) + + results = [] + for strategy_result in strategy_results: + doc_indices = strategy_result.source_doc_indices + if isinstance(input[doc_indices[0]], str): + results.append(strategy_result.text_chunk) + else: + doc_ids = [input[doc_idx][0] for doc_idx in doc_indices] + results.append(( + doc_ids, + strategy_result.text_chunk, + strategy_result.n_tokens, + )) + return results + + +def load_strategy(strategy: ChunkStrategyType) -> ChunkStrategy: + """Load strategy method definition.""" + match strategy: + case ChunkStrategyType.tokens: + from .strategies.tokens import run as run_tokens + + return run_tokens + case ChunkStrategyType.sentence: + # NLTK + from graphrag.index.bootstrap import bootstrap + + from .strategies.sentence import run as run_sentence + + bootstrap() + return run_sentence + case _: + msg = f"Unknown strategy: {strategy}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/text/chunk/typing.py b/func-app/graphrag/index/verbs/text/chunk/typing.py new file mode 100644 index 0000000000..3a42cf68a7 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/chunk/typing.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'TextChunk' model.""" + +from dataclasses import dataclass + + +@dataclass +class TextChunk: + """Text chunk class definition.""" + + text_chunk: str + source_doc_indices: list[int] + n_tokens: int | None = None + + +ChunkInput = str | list[str] | list[tuple[str, str]] +"""Input to a chunking strategy. Can be a string, a list of strings, or a list of tuples of (id, text).""" diff --git a/func-app/graphrag/index/verbs/text/embed/__init__.py b/func-app/graphrag/index/verbs/text/embed/__init__.py new file mode 100644 index 0000000000..969bd2aab9 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/embed/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""The Indexing Engine text embed package root.""" + +from .text_embed import TextEmbedStrategyType, text_embed + +__all__ = ["TextEmbedStrategyType", "text_embed"] diff --git a/func-app/graphrag/index/verbs/text/embed/strategies/__init__.py b/func-app/graphrag/index/verbs/text/embed/strategies/__init__.py new file mode 100644 index 0000000000..8cbe7a580e --- /dev/null +++ b/func-app/graphrag/index/verbs/text/embed/strategies/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine embed strategies package root.""" diff --git a/func-app/graphrag/index/verbs/text/embed/strategies/mock.py b/func-app/graphrag/index/verbs/text/embed/strategies/mock.py new file mode 100644 index 0000000000..1be4ab0f9f --- /dev/null +++ b/func-app/graphrag/index/verbs/text/embed/strategies/mock.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run and _embed_text methods definitions.""" + +import random +from collections.abc import Iterable +from typing import Any + +from datashaper import ProgressTicker, VerbCallbacks, progress_ticker + +from graphrag.index.cache import PipelineCache + +from .typing import TextEmbeddingResult + + +async def run( # noqa RUF029 async is required for interface + input: list[str], + callbacks: VerbCallbacks, + cache: PipelineCache, + _args: dict[str, Any], +) -> TextEmbeddingResult: + """Run the Claim extraction chain.""" + input = input if isinstance(input, Iterable) else [input] + ticker = progress_ticker(callbacks.progress, len(input)) + return TextEmbeddingResult( + embeddings=[_embed_text(cache, text, ticker) for text in input] + ) + + +def _embed_text(_cache: PipelineCache, _text: str, tick: ProgressTicker) -> list[float]: + """Embed a single piece of text.""" + tick(1) + return [random.random(), random.random(), random.random()] # noqa S311 diff --git a/func-app/graphrag/index/verbs/text/embed/strategies/openai.py b/func-app/graphrag/index/verbs/text/embed/strategies/openai.py new file mode 100644 index 0000000000..fb443ec83e --- /dev/null +++ b/func-app/graphrag/index/verbs/text/embed/strategies/openai.py @@ -0,0 +1,181 @@ +# Copyright (c) 2024 Microsoft Corporation. 
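+#
+# Flow sketch (illustrative): inputs are split into token-bounded snippets, batched
+# (default batch_size=16, batch_max_tokens=8191 per the Azure OpenAI limits), embedded
+# concurrently, then multi-snippet inputs are averaged and re-normalized to one vector.
+#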
+# Licensed under the MIT License + +"""A module containing run method definition.""" + +import asyncio +import logging +from typing import Any + +import numpy as np +from datashaper import ProgressTicker, VerbCallbacks, progress_ticker + +import graphrag.config.defaults as defs +from graphrag.index.cache import PipelineCache +from graphrag.index.llm import load_llm_embeddings +from graphrag.index.text_splitting import TokenTextSplitter +from graphrag.index.utils import is_null +from graphrag.llm import EmbeddingLLM, OpenAIConfiguration + +from .typing import TextEmbeddingResult + +log = logging.getLogger(__name__) + + +async def run( + input: list[str], + callbacks: VerbCallbacks, + cache: PipelineCache, + args: dict[str, Any], +) -> TextEmbeddingResult: + """Run the Claim extraction chain.""" + if is_null(input): + return TextEmbeddingResult(embeddings=None) + + llm_config = args.get("llm", {}) + batch_size = args.get("batch_size", 16) + batch_max_tokens = args.get("batch_max_tokens", 8191) + oai_config = OpenAIConfiguration(llm_config) + splitter = _get_splitter(oai_config, batch_max_tokens) + llm = _get_llm(oai_config, callbacks, cache) + semaphore: asyncio.Semaphore = asyncio.Semaphore(args.get("num_threads", 4)) + + # Break up the input texts. The sizes here indicate how many snippets are in each input text + texts, input_sizes = _prepare_embed_texts(input, splitter) + text_batches = _create_text_batches( + texts, + batch_size, + batch_max_tokens, + splitter, + ) + log.info( + "embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, max_tokens=%d", + len(input), + len(texts), + len(text_batches), + batch_size, + batch_max_tokens, + ) + ticker = progress_ticker(callbacks.progress, len(text_batches)) + + # Embed each chunk of snippets + embeddings = await _execute(llm, text_batches, ticker, semaphore) + embeddings = _reconstitute_embeddings(embeddings, input_sizes) + + return TextEmbeddingResult(embeddings=embeddings) + + +def _get_splitter( + config: OpenAIConfiguration, batch_max_tokens: int +) -> TokenTextSplitter: + return TokenTextSplitter( + encoding_name=config.encoding_model or defs.ENCODING_MODEL, + chunk_size=batch_max_tokens, + ) + + +def _get_llm( + config: OpenAIConfiguration, + callbacks: VerbCallbacks, + cache: PipelineCache, +) -> EmbeddingLLM: + llm_type = config.lookup("type", "Unknown") + return load_llm_embeddings( + "text_embedding", + llm_type, + callbacks, + cache, + config.raw_config, + ) + + +async def _execute( + llm: EmbeddingLLM, + chunks: list[list[str]], + tick: ProgressTicker, + semaphore: asyncio.Semaphore, +) -> list[list[float]]: + async def embed(chunk: list[str]): + async with semaphore: + chunk_embeddings = await llm(chunk) + result = np.array(chunk_embeddings.output) + tick(1) + return result + + futures = [embed(chunk) for chunk in chunks] + results = await asyncio.gather(*futures) + # merge results in a single list of lists (reduce the collect dimension) + return [item for sublist in results for item in sublist] + + +def _create_text_batches( + texts: list[str], + max_batch_size: int, + max_batch_tokens: int, + splitter: TokenTextSplitter, +) -> list[list[str]]: + """Create batches of texts to embed.""" + # https://learn.microsoft.com/en-us/azure/ai-services/openai/reference + # According to this embeddings reference, Azure limits us to 16 concurrent embeddings and 8191 tokens per request + result = [] + current_batch = [] + current_batch_tokens = 0 + + for text in texts: + token_count = splitter.num_tokens(text) + if ( + 
len(current_batch) >= max_batch_size + or current_batch_tokens + token_count > max_batch_tokens + ): + result.append(current_batch) + current_batch = [] + current_batch_tokens = 0 + + current_batch.append(text) + current_batch_tokens += token_count + + if len(current_batch) > 0: + result.append(current_batch) + + return result + + +def _prepare_embed_texts( + input: list[str], splitter: TokenTextSplitter +) -> tuple[list[str], list[int]]: + sizes: list[int] = [] + snippets: list[str] = [] + + for text in input: + # Split the input text and filter out any empty content + split_texts = splitter.split_text(text) + if split_texts is None: + continue + split_texts = [text for text in split_texts if len(text) > 0] + + sizes.append(len(split_texts)) + snippets.extend(split_texts) + + return snippets, sizes + + +def _reconstitute_embeddings( + raw_embeddings: list[list[float]], sizes: list[int] +) -> list[list[float] | None]: + """Reconstitute the embeddings into the original input texts.""" + embeddings: list[list[float] | None] = [] + cursor = 0 + for size in sizes: + if size == 0: + embeddings.append(None) + elif size == 1: + embedding = raw_embeddings[cursor] + embeddings.append(embedding) + cursor += 1 + else: + chunk = raw_embeddings[cursor : cursor + size] + average = np.average(chunk, axis=0) + normalized = average / np.linalg.norm(average) + embeddings.append(normalized.tolist()) + cursor += size + return embeddings diff --git a/func-app/graphrag/index/verbs/text/embed/strategies/typing.py b/func-app/graphrag/index/verbs/text/embed/strategies/typing.py new file mode 100644 index 0000000000..1b25256497 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/embed/strategies/typing.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'TextEmbeddingResult' model.""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass + +from datashaper import VerbCallbacks + +from graphrag.index.cache import PipelineCache + + +@dataclass +class TextEmbeddingResult: + """Text embedding result class definition.""" + + embeddings: list[list[float] | None] | None + + +TextEmbeddingStrategy = Callable[ + [ + list[str], + VerbCallbacks, + PipelineCache, + dict, + ], + Awaitable[TextEmbeddingResult], +] diff --git a/func-app/graphrag/index/verbs/text/embed/text_embed.py b/func-app/graphrag/index/verbs/text/embed/text_embed.py new file mode 100644 index 0000000000..cd9bbc798d --- /dev/null +++ b/func-app/graphrag/index/verbs/text/embed/text_embed.py @@ -0,0 +1,269 @@ +# Copyright (c) 2024 Microsoft Corporation. 
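+#
+# Routing sketch (illustrative): if a vector_store config is present and not marked
+# index_in_memory, embeddings are written to the configured store in batches
+# (default 500 rows); otherwise they are kept in an in-memory dataframe column.
+#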
+# Licensed under the MIT License + +"""A module containing text_embed, load_strategy and create_row_from_embedding_data methods definition.""" + +import logging +from enum import Enum +from typing import Any, cast + +import numpy as np +import pandas as pd +from datashaper import TableContainer, VerbCallbacks, VerbInput, verb + +from graphrag.index.cache import PipelineCache +from graphrag.vector_stores import ( + BaseVectorStore, + VectorStoreDocument, + VectorStoreFactory, +) + +from .strategies.typing import TextEmbeddingStrategy + +log = logging.getLogger(__name__) + +# Per Azure OpenAI Limits +# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference +DEFAULT_EMBEDDING_BATCH_SIZE = 500 + + +class TextEmbedStrategyType(str, Enum): + """TextEmbedStrategyType class definition.""" + + openai = "openai" + mock = "mock" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + +@verb(name="text_embed") +async def text_embed( + input: VerbInput, + callbacks: VerbCallbacks, + cache: PipelineCache, + column: str, + strategy: dict, + **kwargs, +) -> TableContainer: + """ + Embed a piece of text into a vector space. The verb outputs a new column containing a mapping between doc_id and vector. + + ## Usage + ```yaml + verb: text_embed + args: + column: text # The name of the column containing the text to embed, this can either be a column with text, or a column with a list[tuple[doc_id, str]] + to: embedding # The name of the column to output the embedding to + strategy: # See strategies section below + ``` + + ## Strategies + The text embed verb uses a strategy to embed the text. The strategy is an object which defines the strategy to use. The following strategies are available: + + ### openai + This strategy uses openai to embed a piece of text. In particular it uses a LLM to embed a piece of text. 
The strategy config is as follows: + + ```yaml + strategy: + type: openai + llm: # The configuration for the LLM + type: openai_embedding # the type of llm to use, available options are: openai_embedding, azure_openai_embedding + api_key: !ENV ${GRAPHRAG_OPENAI_API_KEY} # The api key to use for openai + model: !ENV ${GRAPHRAG_OPENAI_MODEL:gpt-4-turbo-preview} # The model to use for openai + max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai + organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai + vector_store: # The optional configuration for the vector store + type: lancedb # The type of vector store to use, available options are: azure_ai_search, lancedb, kusto + <...> + ``` + """ + vector_store_config = strategy.get("vector_store") + + if vector_store_config and not vector_store_config.get("index_in_memory"): + embedding_name = kwargs.get("embedding_name", "default") + vector_name = kwargs.get("vector_name", "vector") + collection_name = _get_collection_name(vector_store_config, embedding_name) + vector_name = _get_collection_name(vector_store_config, vector_name) + vector_store: BaseVectorStore = _create_vector_store( + vector_store_config, collection_name, vector_name, "reports" + ) + vector_store_workflow_config = vector_store_config.get( + embedding_name, vector_store_config + ) + return await _text_embed_with_vector_store( + input, + callbacks, + cache, + column, + strategy, + vector_store, + vector_store_workflow_config, + vector_store_config.get("store_in_table", False), + kwargs.get("to", f"{column}_embedding"), + ) + + return await _text_embed_in_memory( + input, + callbacks, + cache, + column, + strategy, + kwargs.get("to", f"{column}_embedding"), + ) + + +async def _text_embed_in_memory( + input: VerbInput, + callbacks: VerbCallbacks, + cache: PipelineCache, + column: str, + strategy: dict, + to: str, +): + output_df = cast(pd.DataFrame, input.get_input()) + strategy_type = strategy["type"] + strategy_exec = load_strategy(strategy_type) + strategy_args = {**strategy} + input_table = input.get_input() + + texts: list[str] = input_table[column].to_numpy().tolist() + result = await strategy_exec(texts, callbacks, cache, strategy_args) + + output_df[to] = result.embeddings + return TableContainer(table=output_df) + + +async def _text_embed_with_vector_store( + input: VerbInput, + callbacks: VerbCallbacks, + cache: PipelineCache, + column: str, + strategy: dict[str, Any], + vector_store: BaseVectorStore, + vector_store_config: dict, + store_in_table: bool = False, + to: str = "", +): + output_df = cast(pd.DataFrame, input.get_input()) + strategy_type = strategy["type"] + strategy_exec = load_strategy(strategy_type) + strategy_args = {**strategy} + + # Get vector-storage configuration + insert_batch_size: int = ( + vector_store_config.get("batch_size") or DEFAULT_EMBEDDING_BATCH_SIZE + ) + title_column: str = vector_store_config.get("title_column", "title") + id_column: str = vector_store_config.get("id_column", "id") + overwrite: bool = vector_store_config.get("overwrite", True) + + if column not in output_df.columns: + msg = f"Column {column} not found in input dataframe with columns {output_df.columns}" + raise ValueError(msg) + if title_column not in output_df.columns: + msg = f"Column {title_column} not found in input dataframe with columns {output_df.columns}" + raise ValueError(msg) + if id_column not in output_df.columns: + msg = f"Column {id_column} not found in input dataframe with columns 
{output_df.columns}" + raise ValueError(msg) + + total_rows = 0 + for row in output_df[column]: + if isinstance(row, list): + total_rows += len(row) + else: + total_rows += 1 + + i = 0 + starting_index = 0 + + all_results = [] + + while insert_batch_size * i < input.get_input().shape[0]: + batch = input.get_input().iloc[ + insert_batch_size * i : insert_batch_size * (i + 1) + ] + texts: list[str] = batch[column].to_numpy().tolist() + titles: list[str] = batch[title_column].to_numpy().tolist() + ids: list[str] = batch[id_column].to_numpy().tolist() + result = await strategy_exec( + texts, + callbacks, + cache, + strategy_args, + ) + if store_in_table and result.embeddings: + embeddings = [ + embedding for embedding in result.embeddings if embedding is not None + ] + all_results.extend(embeddings) + + vectors = result.embeddings or [] + documents: list[VectorStoreDocument] = [] + for id, text, title, vector in zip(ids, texts, titles, vectors, strict=True): + if type(vector) is np.ndarray: + vector = vector.tolist() + document = VectorStoreDocument( + id=id, + text=text, + vector=vector, + attributes={"title": title}, + ) + documents.append(document) + + vector_store.load_documents(documents, overwrite and i == 0) + starting_index += len(documents) + i += 1 + + if store_in_table: + output_df[to] = all_results + + return TableContainer(table=output_df) + + +def _create_vector_store( + vector_store_config: dict, collection_name: str, vector_name: str, reports_name: str, +) -> BaseVectorStore: + vector_store_type: str = str(vector_store_config.get("type")) + if collection_name: + vector_store_config.update({"collection_name": collection_name}) + if vector_name: + vector_store_config.update({"vector_name": vector_name}) + if reports_name: + vector_store_config.update({"reports_name": reports_name}) + + vector_store = VectorStoreFactory.get_vector_store( + vector_store_type, kwargs=vector_store_config + ) + + vector_store.connect(**vector_store_config) + return vector_store + + +def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str: + collection_name = vector_store_config.get("collection_name") + if not collection_name: + collection_names = vector_store_config.get("collection_names", {}) + collection_name = collection_names.get(embedding_name, embedding_name) + + msg = f"using {vector_store_config.get('type')} collection_name {collection_name} for embedding {embedding_name}" + log.info(msg) + return collection_name + + +def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy: + """Load strategy method definition.""" + match strategy: + case TextEmbedStrategyType.openai: + from .strategies.openai import run as run_openai + + return run_openai + case TextEmbedStrategyType.mock: + from .strategies.mock import run as run_mock + + return run_mock + case _: + msg = f"Unknown strategy: {strategy}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/text/replace/__init__.py b/func-app/graphrag/index/verbs/text/replace/__init__.py new file mode 100644 index 0000000000..f863415f40 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/replace/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. 
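+#
+# Note (illustrative): replacements are applied with plain str.replace, so each
+# pattern is matched as a literal substring, e.g.
+#   replacements: [{"pattern": "teh", "replacement": "the"}]
+#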
+# Licensed under the MIT License
+
+"""The Indexing Engine text replace package root."""
+
+from .replace import text_replace
+
+__all__ = ["text_replace"]
diff --git a/func-app/graphrag/index/verbs/text/replace/replace.py b/func-app/graphrag/index/verbs/text/replace/replace.py
new file mode 100644
index 0000000000..386fac3459
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/replace/replace.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing text_replace and _apply_replacements methods."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+from .typing import Replacement
+
+
+@verb(name="text_replace")
+def text_replace(
+    input: VerbInput,
+    column: str,
+    to: str,
+    replacements: list[dict[str, str]],
+    **_kwargs: dict,
+) -> TableContainer:
+    """
+    Apply a set of replacements to a piece of text.
+
+    ## Usage
+    ```yaml
+    verb: text_replace
+    args:
+        column: # The name of the column containing the text to replace
+        to: # The name of the column to write the replaced text to
+        replacements: # A list of replacements to apply
+            - pattern: # The literal text to find (str.replace is used, not a regex)
+              replacement: # The string to replace with
+    ```
+    """
+    output = cast(pd.DataFrame, input.get_input())
+    parsed_replacements = [Replacement(**r) for r in replacements]
+    output[to] = output[column].apply(
+        lambda text: _apply_replacements(text, parsed_replacements)
+    )
+    return TableContainer(table=output)
+
+
+def _apply_replacements(text: str, replacements: list[Replacement]) -> str:
+    for r in replacements:
+        text = text.replace(r.pattern, r.replacement)
+    return text
diff --git a/func-app/graphrag/index/verbs/text/replace/typing.py b/func-app/graphrag/index/verbs/text/replace/typing.py
new file mode 100644
index 0000000000..45beef9f28
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/replace/typing.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing 'Replacement' model."""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class Replacement:
+    """Replacement class definition."""
+
+    pattern: str
+    replacement: str
diff --git a/func-app/graphrag/index/verbs/text/split.py b/func-app/graphrag/index/verbs/text/split.py
new file mode 100644
index 0000000000..b1339ff455
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/split.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the text_split method definition."""
+
+from typing import cast
+
+import pandas as pd
+from datashaper import TableContainer, VerbInput, verb
+
+
+@verb(name="text_split")
+def text_split(
+    input: VerbInput,
+    column: str,
+    to: str,
+    separator: str = ",",
+    **_kwargs: dict,
+) -> TableContainer:
+    """
+    Split a piece of text into a list of strings based on a delimiter. The verb outputs a new column containing a list of strings.
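+    For example, with the default separator, "a,b,c" becomes ["a", "b", "c"].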
+ + ## Usage + + ```yaml + verb: text_split + args: + column: text # The name of the column containing the text to split + to: split_text # The name of the column to output the split text to + separator: "," # The separator to split the text on, defaults to "," + ``` + """ + output = text_split_df(cast(pd.DataFrame, input.get_input()), column, to, separator) + return TableContainer(table=output) + + +def text_split_df( + input: pd.DataFrame, column: str, to: str, separator: str = "," +) -> pd.DataFrame: + """Split a column into a list of strings.""" + output = input + + def _apply_split(row): + if row[column] is None or isinstance(row[column], list): + return row[column] + if row[column] == "": + return [] + if not isinstance(row[column], str): + message = f"Expected {column} to be a string, but got {type(row[column])}" + raise TypeError(message) + return row[column].split(separator) + + output[to] = output.apply(_apply_split, axis=1) + return output diff --git a/func-app/graphrag/index/verbs/text/translate/__init__.py b/func-app/graphrag/index/verbs/text/translate/__init__.py new file mode 100644 index 0000000000..ad830dfa87 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/translate/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine text translate package root.""" + +from .text_translate import text_translate + +__all__ = ["text_translate"] diff --git a/func-app/graphrag/index/verbs/text/translate/strategies/__init__.py b/func-app/graphrag/index/verbs/text/translate/strategies/__init__.py new file mode 100644 index 0000000000..d418bbae28 --- /dev/null +++ b/func-app/graphrag/index/verbs/text/translate/strategies/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine translate strategies package root.""" + +from .mock import run as run_mock +from .openai import run as run_openai + +__all__ = ["run_mock", "run_openai"] diff --git a/func-app/graphrag/index/verbs/text/translate/strategies/defaults.py b/func-app/graphrag/index/verbs/text/translate/strategies/defaults.py new file mode 100644 index 0000000000..003e00eb1f --- /dev/null +++ b/func-app/graphrag/index/verbs/text/translate/strategies/defaults.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A file containing TRANSLATION_PROMPT value definition.""" + +TRANSLATION_PROMPT = """ + You are a helpful assistant. Translate into {language} the following text, and make sure all of the text is in {language}. + """.strip() diff --git a/func-app/graphrag/index/verbs/text/translate/strategies/mock.py b/func-app/graphrag/index/verbs/text/translate/strategies/mock.py new file mode 100644 index 0000000000..58a5a9995e --- /dev/null +++ b/func-app/graphrag/index/verbs/text/translate/strategies/mock.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Microsoft Corporation. 
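A quick illustration of the `text_split_df` helper defined above, as a simplified equivalent that omits the None/list/type-error handling of the real function: a delimited string column becomes a list column, with empty strings mapping to empty lists.

```python
import pandas as pd

df = pd.DataFrame({"source_id": ["t1,t2", "t3", ""]})

# Simplified equivalent of text_split_df(df, "source_id", "text_unit_ids"):
# "" maps to [], everything else is split on the separator.
df["text_unit_ids"] = df["source_id"].apply(lambda v: v.split(",") if v else [])

print(df["text_unit_ids"].tolist())  # [['t1', 't2'], ['t3'], []]
```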
+# Licensed under the MIT License
+
+"""A module containing run and _translate_text methods definitions."""
+
+from typing import Any
+
+from datashaper import VerbCallbacks
+
+from graphrag.index.cache import PipelineCache
+
+from .typing import TextTranslationResult
+
+
+async def run(  # noqa RUF029 async is required for interface
+    input: str | list[str],
+    _args: dict[str, Any],
+    _reporter: VerbCallbacks,
+    _cache: PipelineCache,
+) -> TextTranslationResult:
+    """Run the mock text translation strategy."""
+    input = [input] if isinstance(input, str) else input
+    return TextTranslationResult(translations=[_translate_text(text) for text in input])
+
+
+def _translate_text(text: str) -> str:
+    """Translate a single piece of text."""
+    return f"{text} translated"
diff --git a/func-app/graphrag/index/verbs/text/translate/strategies/openai.py b/func-app/graphrag/index/verbs/text/translate/strategies/openai.py
new file mode 100644
index 0000000000..49c47b34a2
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/translate/strategies/openai.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run and _translate_text methods definitions."""
+
+import logging
+import traceback
+from typing import Any
+
+from datashaper import VerbCallbacks
+
+import graphrag.config.defaults as defs
+from graphrag.config.enums import LLMType
+from graphrag.index.cache import PipelineCache
+from graphrag.index.llm import load_llm
+from graphrag.index.text_splitting import TokenTextSplitter
+from graphrag.llm import CompletionLLM
+
+from .defaults import TRANSLATION_PROMPT as DEFAULT_TRANSLATION_PROMPT
+from .typing import TextTranslationResult
+
+log = logging.getLogger(__name__)
+
+
+async def run(
+    input: str | list[str],
+    args: dict[str, Any],
+    callbacks: VerbCallbacks,
+    pipeline_cache: PipelineCache,
+) -> TextTranslationResult:
+    """Run the text translation chain."""
+    llm_config = args.get("llm", {"type": LLMType.StaticResponse})
+    llm_type = llm_config.get("type", LLMType.StaticResponse)
+    llm = load_llm(
+        "text_translation",
+        llm_type,
+        callbacks,
+        pipeline_cache,
+        llm_config,
+        chat_only=True,
+    )
+    language = args.get("language", "English")
+    prompt = args.get("prompt")
+    chunk_size = args.get("chunk_size", defs.CHUNK_SIZE)
+    chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP)
+
+    input = [input] if isinstance(input, str) else input
+    return TextTranslationResult(
+        translations=[
+            await _translate_text(
+                text, language, prompt, llm, chunk_size, chunk_overlap, callbacks
+            )
+            for text in input
+        ]
+    )
+
+
+async def _translate_text(
+    text: str,
+    language: str,
+    prompt: str | None,
+    llm: CompletionLLM,
+    chunk_size: int,
+    chunk_overlap: int,
+    callbacks: VerbCallbacks,
+) -> str:
+    """Translate a single piece of text."""
+    splitter = TokenTextSplitter(
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+    )
+
+    out = ""
+    chunks = splitter.split_text(text)
+    for chunk in chunks:
+        try:
+            result = await llm(
+                chunk,
+                history=[
+                    {
+                        "role": "system",
+                        "content": (prompt or DEFAULT_TRANSLATION_PROMPT),
+                    }
+                ],
+                variables={"language": language},
+            )
+            out += result.output or ""
+        except Exception as e:
+            log.exception("error translating text")
+            callbacks.error("Error translating text", e, traceback.format_exc())
+            out += ""
+
+    return out
diff --git a/func-app/graphrag/index/verbs/text/translate/strategies/typing.py
new file mode 100644
index 0000000000..d91ed735f5
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/translate/strategies/typing.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing 'TextTranslationResult' model."""
+
+from collections.abc import Awaitable, Callable
+from dataclasses import dataclass
+from typing import Any
+
+from datashaper import VerbCallbacks
+
+from graphrag.index.cache import PipelineCache
+
+
+@dataclass
+class TextTranslationResult:
+    """Text translation result class definition."""
+
+    translations: list[str]
+
+
+TextTranslationStrategy = Callable[
+    [list[str], dict[str, Any], VerbCallbacks, PipelineCache],
+    Awaitable[TextTranslationResult],
+]
diff --git a/func-app/graphrag/index/verbs/text/translate/text_translate.py b/func-app/graphrag/index/verbs/text/translate/text_translate.py
new file mode 100644
index 0000000000..8d0faffefa
--- /dev/null
+++ b/func-app/graphrag/index/verbs/text/translate/text_translate.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing text_translate methods definition."""
+
+from enum import Enum
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    AsyncType,
+    TableContainer,
+    VerbCallbacks,
+    VerbInput,
+    derive_from_rows,
+    verb,
+)
+
+from graphrag.index.cache import PipelineCache
+
+from .strategies.typing import TextTranslationStrategy
+
+
+class TextTranslateStrategyType(str, Enum):
+    """TextTranslateStrategyType class definition."""
+
+    openai = "openai"
+    mock = "mock"
+
+    def __repr__(self):
+        """Get a string representation."""
+        return f'"{self.value}"'
+
+
+@verb(name="text_translate")
+async def text_translate(
+    input: VerbInput,
+    cache: PipelineCache,
+    callbacks: VerbCallbacks,
+    text_column: str,
+    to: str,
+    strategy: dict[str, Any],
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    **kwargs,
+) -> TableContainer:
+    """
+    Translate a piece of text into another language.
+
+    ## Usage
+    ```yaml
+    verb: text_translate
+    args:
+        text_column: # The name of the column containing the text to translate
+        to: # The name of the column to write the translated text to
+        strategy: # The strategy to use to translate the text, see below for more details
+    ```
+
+    ## Strategies
+    The text translate verb uses a strategy to translate the text. The strategy is an object that defines how the translation is performed. The following strategies are available:
+
+    ### openai
+    This strategy uses openai to translate a piece of text. In particular, it uses an LLM to translate a piece of text.
The strategy config is as follows: + + ```yaml + strategy: + type: openai + language: english # The language to translate to, default: english + prompt: # The prompt to use for the translation, default: None + chunk_size: 2500 # The chunk size to use for the translation, default: 2500 + chunk_overlap: 0 # The chunk overlap to use for the translation, default: 0 + llm: # The configuration for the LLM + type: openai_chat # the type of llm to use, available options are: openai_chat, azure_openai_chat + api_key: !ENV ${GRAPHRAG_OPENAI_API_KEY} # The api key to use for openai + model: !ENV ${GRAPHRAG_OPENAI_MODEL:gpt-4-turbo-preview} # The model to use for openai + max_tokens: !ENV ${GRAPHRAG_MAX_TOKENS:6000} # The max tokens to use for openai + organization: !ENV ${GRAPHRAG_OPENAI_ORGANIZATION} # The organization to use for openai + ``` + """ + output_df = cast(pd.DataFrame, input.get_input()) + strategy_type = strategy["type"] + strategy_args = {**strategy} + strategy_exec = _load_strategy(strategy_type) + + async def run_strategy(row): + text = row[text_column] + result = await strategy_exec(text, strategy_args, callbacks, cache) + + # If it is a single string, then return just the translation for that string + if isinstance(text, str): + return result.translations[0] + + # Otherwise, return a list of translations, one for each item in the input + return list(result.translations) + + results = await derive_from_rows( + output_df, + run_strategy, + callbacks, + scheduling_type=async_mode, + num_threads=kwargs.get("num_threads", 4), + ) + output_df[to] = results + return TableContainer(table=output_df) + + +def _load_strategy(strategy: TextTranslateStrategyType) -> TextTranslationStrategy: + match strategy: + case TextTranslateStrategyType.openai: + from .strategies.openai import run as run_openai + + return run_openai + + case TextTranslateStrategyType.mock: + from .strategies.mock import run as run_mock + + return run_mock + + case _: + msg = f"Unknown strategy: {strategy}" + raise ValueError(msg) diff --git a/func-app/graphrag/index/verbs/unzip.py b/func-app/graphrag/index/verbs/unzip.py new file mode 100644 index 0000000000..4d8c8da08e --- /dev/null +++ b/func-app/graphrag/index/verbs/unzip.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing unzip method definition.""" + +from typing import cast + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + + +# TODO: Check if this is already a thing +# Takes 1|(x,y)|b +# and converts to +# 1|x|y|b +@verb(name="unzip") +def unzip( + input: VerbInput, column: str, to: list[str], **_kwargs: dict +) -> TableContainer: + """Unpacks a column containing a tuple into multiple columns.""" + table = cast(pd.DataFrame, input.get_input()) + + table[to] = pd.DataFrame(table[column].tolist(), index=table.index) + + return TableContainer(table=table) diff --git a/func-app/graphrag/index/verbs/zip.py b/func-app/graphrag/index/verbs/zip.py new file mode 100644 index 0000000000..462395d3da --- /dev/null +++ b/func-app/graphrag/index/verbs/zip.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024 Microsoft Corporation. 
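With the verb and both strategies in place, here is a minimal sketch of the strategy contract they share; `TextTranslationResult` is restated so the snippet runs standalone, and the stub mirrors the mock strategy's behavior of tagging text rather than calling an LLM.

```python
import asyncio
from dataclasses import dataclass

@dataclass
class TextTranslationResult:
    translations: list[str]

async def mock_strategy(texts, _args, _callbacks, _cache) -> TextTranslationResult:
    # Mirrors strategies/mock.py: append a marker instead of calling an LLM.
    return TextTranslationResult(translations=[f"{t} translated" for t in texts])

async def run_row(text: str) -> str:
    result = await mock_strategy([text], {}, None, None)
    # A single input string yields the first (and only) translation.
    return result.translations[0]

print(asyncio.run(run_row("hola")))  # "hola translated"
```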
+# Licensed under the MIT License + +"""A module containing ds_zip method definition.""" + +from typing import cast + +import pandas as pd +from datashaper import TableContainer, VerbInput, verb + + +@verb(name="zip") +def zip_verb( + input: VerbInput, + to: str, + columns: list[str], + type: str | None = None, # noqa A002 + **_kwargs: dict, +) -> TableContainer: + """ + Zip columns together. + + ## Usage + TODO + + """ + table = cast(pd.DataFrame, input.get_input()) + if type is None: + table[to] = list(zip(*[table[col] for col in columns], strict=True)) + + # This one is a little weird + elif type == "dict": + if len(columns) != 2: + msg = f"Expected exactly two columns for a dict, got {columns}" + raise ValueError(msg) + key_col, value_col = columns + + results = [] + for _, row in table.iterrows(): + keys = row[key_col] + values = row[value_col] + output = {} + if len(keys) != len(values): + msg = f"Expected same number of keys and values, got {len(keys)} keys and {len(values)} values" + raise ValueError(msg) + for idx, key in enumerate(keys): + output[key] = values[idx] + results.append(output) + + table[to] = results + return TableContainer(table=table.reset_index(drop=True)) diff --git a/func-app/graphrag/index/workflows/__init__.py b/func-app/graphrag/index/workflows/__init__.py new file mode 100644 index 0000000000..ed580309a8 --- /dev/null +++ b/func-app/graphrag/index/workflows/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine workflows package root.""" + +from .load import create_workflow, load_workflows +from .typing import ( + StepDefinition, + VerbDefinitions, + VerbTiming, + WorkflowConfig, + WorkflowDefinitions, + WorkflowToRun, +) + +__all__ = [ + "StepDefinition", + "VerbDefinitions", + "VerbTiming", + "WorkflowConfig", + "WorkflowDefinitions", + "WorkflowToRun", + "create_workflow", + "load_workflows", +] diff --git a/func-app/graphrag/index/workflows/default_workflows.py b/func-app/graphrag/index/workflows/default_workflows.py new file mode 100644 index 0000000000..81112bee32 --- /dev/null +++ b/func-app/graphrag/index/workflows/default_workflows.py @@ -0,0 +1,121 @@ +# Copyright (c) 2024 Microsoft Corporation. 
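The `unzip` verb above and the tuple branch of the `zip` verb just defined are inverses of each other; a minimal pandas sketch of the round trip they implement:

```python
import pandas as pd

df = pd.DataFrame({"id": [1, 2], "text": ["a", "b"]})

# zip: pack columns into tuples...
df["text_with_ids"] = list(zip(df["id"], df["text"], strict=True))

# ...unzip: unpack them back out into named columns.
df[["id2", "text2"]] = pd.DataFrame(df["text_with_ids"].tolist(), index=df.index)

assert df["id2"].tolist() == [1, 2] and df["text2"].tolist() == ["a", "b"]
```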
+# Licensed under the MIT License + +"""A package containing default workflows definitions.""" + +from .typing import WorkflowDefinitions +from .v1.create_base_documents import ( + build_steps as build_create_base_documents_steps, +) +from .v1.create_base_documents import ( + workflow_name as create_base_documents, +) +from .v1.create_base_entity_graph import ( + build_steps as build_create_base_entity_graph_steps, +) +from .v1.create_base_entity_graph import ( + workflow_name as create_base_entity_graph, +) +from .v1.create_base_extracted_entities import ( + build_steps as build_create_base_extracted_entities_steps, +) +from .v1.create_base_extracted_entities import ( + workflow_name as create_base_extracted_entities, +) +from .v1.create_base_text_units import ( + build_steps as build_create_base_text_units_steps, +) +from .v1.create_base_text_units import ( + workflow_name as create_base_text_units, +) +from .v1.create_final_communities import ( + build_steps as build_create_final_communities_steps, +) +from .v1.create_final_communities import ( + workflow_name as create_final_communities, +) +from .v1.create_final_community_reports import ( + build_steps as build_create_final_community_reports_steps, +) +from .v1.create_final_community_reports import ( + workflow_name as create_final_community_reports, +) +from .v1.create_final_covariates import ( + build_steps as build_create_final_covariates_steps, +) +from .v1.create_final_covariates import ( + workflow_name as create_final_covariates, +) +from .v1.create_final_documents import ( + build_steps as build_create_final_documents_steps, +) +from .v1.create_final_documents import ( + workflow_name as create_final_documents, +) +from .v1.create_final_entities import ( + build_steps as build_create_final_entities_steps, +) +from .v1.create_final_entities import ( + workflow_name as create_final_entities, +) +from .v1.create_final_nodes import ( + build_steps as build_create_final_nodes_steps, +) +from .v1.create_final_nodes import ( + workflow_name as create_final_nodes, +) +from .v1.create_final_relationships import ( + build_steps as build_create_final_relationships_steps, +) +from .v1.create_final_relationships import ( + workflow_name as create_final_relationships, +) +from .v1.create_final_text_units import ( + build_steps as build_create_final_text_units, +) +from .v1.create_final_text_units import ( + workflow_name as create_final_text_units, +) +from .v1.create_summarized_entities import ( + build_steps as build_create_summarized_entities_steps, +) +from .v1.create_summarized_entities import ( + workflow_name as create_summarized_entities, +) +from .v1.join_text_units_to_covariate_ids import ( + build_steps as join_text_units_to_covariate_ids_steps, +) +from .v1.join_text_units_to_covariate_ids import ( + workflow_name as join_text_units_to_covariate_ids, +) +from .v1.join_text_units_to_entity_ids import ( + build_steps as join_text_units_to_entity_ids_steps, +) +from .v1.join_text_units_to_entity_ids import ( + workflow_name as join_text_units_to_entity_ids, +) +from .v1.join_text_units_to_relationship_ids import ( + build_steps as join_text_units_to_relationship_ids_steps, +) +from .v1.join_text_units_to_relationship_ids import ( + workflow_name as join_text_units_to_relationship_ids, +) + +default_workflows: WorkflowDefinitions = { + create_base_extracted_entities: build_create_base_extracted_entities_steps, + create_base_entity_graph: build_create_base_entity_graph_steps, + create_base_text_units: 
build_create_base_text_units_steps, + create_final_text_units: build_create_final_text_units, + create_final_community_reports: build_create_final_community_reports_steps, + create_final_nodes: build_create_final_nodes_steps, + create_final_relationships: build_create_final_relationships_steps, + create_final_documents: build_create_final_documents_steps, + create_final_covariates: build_create_final_covariates_steps, + create_base_documents: build_create_base_documents_steps, + create_final_entities: build_create_final_entities_steps, + create_final_communities: build_create_final_communities_steps, + create_summarized_entities: build_create_summarized_entities_steps, + join_text_units_to_entity_ids: join_text_units_to_entity_ids_steps, + join_text_units_to_covariate_ids: join_text_units_to_covariate_ids_steps, + join_text_units_to_relationship_ids: join_text_units_to_relationship_ids_steps, +} diff --git a/func-app/graphrag/index/workflows/load.py b/func-app/graphrag/index/workflows/load.py new file mode 100644 index 0000000000..4dd6f9bfd0 --- /dev/null +++ b/func-app/graphrag/index/workflows/load.py @@ -0,0 +1,171 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing load_workflows, create_workflow, _get_steps_for_workflow and _remove_disabled_steps methods definition.""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, NamedTuple, cast + +from datashaper import Workflow + +from graphrag.index.errors import ( + NoWorkflowsDefinedError, + UndefinedWorkflowError, + UnknownWorkflowError, +) +from graphrag.index.utils import topological_sort + +from .default_workflows import default_workflows as _default_workflows +from .typing import VerbDefinitions, WorkflowDefinitions, WorkflowToRun + +if TYPE_CHECKING: + from graphrag.index.config import ( + PipelineWorkflowConfig, + PipelineWorkflowReference, + PipelineWorkflowStep, + ) + +anonymous_workflow_count = 0 + +VerbFn = Callable[..., Any] +log = logging.getLogger(__name__) + + +class LoadWorkflowResult(NamedTuple): + """A workflow loading result object.""" + + workflows: list[WorkflowToRun] + """The loaded workflow names in the order they should be run.""" + + dependencies: dict[str, list[str]] + """A dictionary of workflow name to workflow dependencies.""" + + +def load_workflows( + workflows_to_load: list[PipelineWorkflowReference], + additional_verbs: VerbDefinitions | None = None, + additional_workflows: WorkflowDefinitions | None = None, + memory_profile: bool = False, +) -> LoadWorkflowResult: + """Load the given workflows. 
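The run order is derived topologically from the inter-workflow dependencies. A minimal sketch using the standard library's `graphlib`, as an illustrative equivalent of the `topological_sort` helper imported above (whose implementation is not shown in this patch):

```python
from graphlib import TopologicalSorter

# task_graph maps each workflow to the workflows it depends on,
# the same shape that filter_wf_dependencies produces below.
task_graph = {
    "create_base_text_units": [],
    "create_base_extracted_entities": ["create_base_text_units"],
    "create_base_entity_graph": ["create_base_extracted_entities"],
}
run_order = list(TopologicalSorter(task_graph).static_order())
print(run_order)
# ['create_base_text_units', 'create_base_extracted_entities', 'create_base_entity_graph']
```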
+ + Args: + - workflows_to_load - The workflows to load + - additional_verbs - The list of custom verbs available to the workflows + - additional_workflows - The list of custom workflows + Returns: + - output[0] - The loaded workflow names in the order they should be run + - output[1] - A dictionary of workflow name to workflow dependencies + """ + workflow_graph: dict[str, WorkflowToRun] = {} + + global anonymous_workflow_count + for reference in workflows_to_load: + name = reference.name + is_anonymous = name is None or name.strip() == "" + if is_anonymous: + name = f"Anonymous Workflow {anonymous_workflow_count}" + anonymous_workflow_count += 1 + name = cast(str, name) + + config = reference.config + workflow = create_workflow( + name or "MISSING NAME!", + reference.steps, + config, + additional_verbs, + additional_workflows, + ) + workflow_graph[name] = WorkflowToRun(workflow, config=config or {}) + + # Backfill any missing workflows + for name in list(workflow_graph.keys()): + workflow = workflow_graph[name] + deps = [ + d.replace("workflow:", "") + for d in workflow.workflow.dependencies + if d.startswith("workflow:") + ] + for dependency in deps: + if dependency not in workflow_graph: + reference = {"name": dependency, **workflow.config} + workflow_graph[dependency] = WorkflowToRun( + workflow=create_workflow( + dependency, + config=reference, + additional_verbs=additional_verbs, + additional_workflows=additional_workflows, + memory_profile=memory_profile, + ), + config=reference, + ) + + # Run workflows in order of dependencies + def filter_wf_dependencies(name: str) -> list[str]: + externals = [ + e.replace("workflow:", "") + for e in workflow_graph[name].workflow.dependencies + ] + return [e for e in externals if e in workflow_graph] + + task_graph = {name: filter_wf_dependencies(name) for name in workflow_graph} + workflow_run_order = topological_sort(task_graph) + workflows = [workflow_graph[name] for name in workflow_run_order] + log.info("Workflow Run Order: %s", workflow_run_order) + return LoadWorkflowResult(workflows=workflows, dependencies=task_graph) + + +def create_workflow( + name: str, + steps: list[PipelineWorkflowStep] | None = None, + config: PipelineWorkflowConfig | None = None, + additional_verbs: VerbDefinitions | None = None, + additional_workflows: WorkflowDefinitions | None = None, + memory_profile: bool = False, +) -> Workflow: + """Create a workflow from the given config.""" + additional_workflows = { + **_default_workflows, + **(additional_workflows or {}), + } + steps = steps or _get_steps_for_workflow(name, config, additional_workflows) + steps = _remove_disabled_steps(steps) + return Workflow( + verbs=additional_verbs or {}, + schema={ + "name": name, + "steps": steps, + }, + validate=False, + memory_profile=memory_profile, + ) + + +def _get_steps_for_workflow( + name: str | None, + config: PipelineWorkflowConfig | None, + workflows: dict[str, Callable] | None, +) -> list[PipelineWorkflowStep]: + """Get the steps for the given workflow config.""" + if config is not None and "steps" in config: + return config["steps"] + + if workflows is None: + raise NoWorkflowsDefinedError + + if name is None: + raise UndefinedWorkflowError + + if name not in workflows: + raise UnknownWorkflowError(name) + + return workflows[name](config or {}) + + +def _remove_disabled_steps( + steps: list[PipelineWorkflowStep], +) -> list[PipelineWorkflowStep]: + return [step for step in steps if step.get("enabled", True)] diff --git a/func-app/graphrag/index/workflows/typing.py 
b/func-app/graphrag/index/workflows/typing.py new file mode 100644 index 0000000000..3b44545bd4 --- /dev/null +++ b/func-app/graphrag/index/workflows/typing.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'WorkflowToRun' model.""" + +from collections.abc import Callable +from dataclasses import dataclass as dc_dataclass +from typing import Any + +from datashaper import TableContainer, Workflow + +StepDefinition = dict[str, Any] +"""A step definition.""" + +VerbDefinitions = dict[str, Callable[..., TableContainer]] +"""A mapping of verb names to their implementations.""" + +WorkflowConfig = dict[str, Any] +"""A workflow configuration.""" + +WorkflowDefinitions = dict[str, Callable[[WorkflowConfig], list[StepDefinition]]] +"""A mapping of workflow names to their implementations.""" + +VerbTiming = dict[str, float] +"""The timings of verbs by id.""" + + +@dc_dataclass +class WorkflowToRun: + """Workflow to run class definition.""" + + workflow: Workflow + config: dict[str, Any] diff --git a/func-app/graphrag/index/workflows/v1/__init__.py b/func-app/graphrag/index/workflows/v1/__init__.py new file mode 100644 index 0000000000..69518f5ee2 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Indexing Engine workflows package root.""" diff --git a/func-app/graphrag/index/workflows/v1/create_base_documents.py b/func-app/graphrag/index/workflows/v1/create_base_documents.py new file mode 100644 index 0000000000..bd7094c64a --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_base_documents.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from datashaper import DEFAULT_INPUT_NAME + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_base_documents" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the documents table. 
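Each of the v1 workflow modules that follow returns a plain list of step dictionaries. A minimal sketch of the shape, together with the enabled-step filtering that `_remove_disabled_steps` in `load.py` applies (steps default to enabled):

```python
steps = [
    {"verb": "unroll", "args": {"column": "document_ids"},
     "input": {"source": "workflow:create_final_text_units"}},
    {"verb": "snapshot", "enabled": False, "args": {"name": "debug"}},
]

# Mirrors _remove_disabled_steps in workflows/load.py.
active = [s for s in steps if s.get("enabled", True)]
assert [s["verb"] for s in active] == ["unroll"]
```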
+ + ## Dependencies + * `workflow:create_final_text_units` + """ + document_attribute_columns = config.get("document_attribute_columns", []) + return [ + { + "verb": "unroll", + "args": {"column": "document_ids"}, + "input": {"source": "workflow:create_final_text_units"}, + }, + { + "verb": "select", + "args": { + # We only need the chunk id and the document id + "columns": ["id", "document_ids", "text"] + }, + }, + { + "id": "rename_chunk_doc_id", + "verb": "rename", + "args": { + "columns": { + "document_ids": "chunk_doc_id", + "id": "chunk_id", + "text": "chunk_text", + } + }, + }, + { + "verb": "join", + "args": { + # Join the doc id from the chunk onto the original document + "on": ["chunk_doc_id", "id"] + }, + "input": {"source": "rename_chunk_doc_id", "others": [DEFAULT_INPUT_NAME]}, + }, + { + "id": "docs_with_text_units", + "verb": "aggregate_override", + "args": { + "groupby": ["id"], + "aggregations": [ + { + "column": "chunk_id", + "operation": "array_agg", + "to": "text_units", + } + ], + }, + }, + { + "verb": "join", + "args": { + "on": ["id", "id"], + "strategy": "right outer", + }, + "input": { + "source": "docs_with_text_units", + "others": [DEFAULT_INPUT_NAME], + }, + }, + { + "verb": "rename", + "args": {"columns": {"text": "raw_content"}}, + }, + *[ + { + "verb": "convert", + "args": { + "column": column, + "to": column, + "type": "string", + }, + } + for column in document_attribute_columns + ], + { + "verb": "merge_override", + "enabled": len(document_attribute_columns) > 0, + "args": { + "columns": document_attribute_columns, + "strategy": "json", + "to": "attributes", + }, + }, + {"verb": "convert", "args": {"column": "id", "to": "id", "type": "string"}}, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_base_entity_graph.py b/func-app/graphrag/index/workflows/v1/create_base_entity_graph.py new file mode 100644 index 0000000000..b001aad218 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_base_entity_graph.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_base_entity_graph" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the base table for the entity graph. 
+ + ## Dependencies + * `workflow:create_base_extracted_entities` + """ + clustering_config = config.get( + "cluster_graph", + {"strategy": {"type": "leiden"}}, + ) + embed_graph_config = config.get( + "embed_graph", + { + "strategy": { + "type": "node2vec", + "num_walks": config.get("embed_num_walks", 10), + "walk_length": config.get("embed_walk_length", 40), + "window_size": config.get("embed_window_size", 2), + "iterations": config.get("embed_iterations", 3), + "random_seed": config.get("embed_random_seed", 86), + } + }, + ) + + graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False + embed_graph_enabled = config.get("embed_graph_enabled", False) or False + + return [ + { + "verb": "cluster_graph", + "args": { + **clustering_config, + "column": "entity_graph", + "to": "clustered_graph", + "level_to": "level", + }, + "input": ({"source": "workflow:create_summarized_entities"}), + }, + { + "verb": "snapshot_rows", + "enabled": graphml_snapshot_enabled, + "args": { + "base_name": "clustered_graph", + "column": "clustered_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + }, + { + "verb": "embed_graph", + "enabled": embed_graph_enabled, + "args": { + "column": "clustered_graph", + "to": "embeddings", + **embed_graph_config, + }, + }, + { + "verb": "snapshot_rows", + "enabled": graphml_snapshot_enabled, + "args": { + "base_name": "embedded_graph", + "column": "entity_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + }, + { + "verb": "select", + "args": { + # only selecting for documentation sake, so we know what is contained in + # this workflow + "columns": ( + ["level", "clustered_graph", "embeddings"] + if embed_graph_enabled + else ["level", "clustered_graph"] + ), + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_base_extracted_entities.py b/func-app/graphrag/index/workflows/v1/create_base_extracted_entities.py new file mode 100644 index 0000000000..30d608e9fd --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_base_extracted_entities.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from datashaper import AsyncType + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_base_extracted_entities" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the base table for extracted entities. 
+ + ## Dependencies + * `workflow:create_base_text_units` + """ + entity_extraction_config = config.get("entity_extract", {}) + graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False + raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False + + return [ + { + "verb": "entity_extract", + "args": { + **entity_extraction_config, + "column": entity_extraction_config.get("text_column", "chunk"), + "id_column": entity_extraction_config.get("id_column", "chunk_id"), + "async_mode": entity_extraction_config.get( + "async_mode", AsyncType.AsyncIO + ), + "to": "entities", + "graph_to": "entity_graph", + }, + "input": {"source": "workflow:create_base_text_units"}, + }, + { + "verb": "snapshot", + "enabled": raw_entity_snapshot_enabled, + "args": { + "name": "raw_extracted_entities", + "formats": ["json"], + }, + }, + { + "verb": "merge_graphs", + "args": { + "column": "entity_graph", + "to": "entity_graph", + **config.get( + "graph_merge_operations", + { + "nodes": { + "source_id": { + "operation": "concat", + "delimiter": ", ", + "distinct": True, + }, + "description": ({ + "operation": "concat", + "separator": "\n", + "distinct": False, + }), + }, + "edges": { + "source_id": { + "operation": "concat", + "delimiter": ", ", + "distinct": True, + }, + "description": ({ + "operation": "concat", + "separator": "\n", + "distinct": False, + }), + "weight": "sum", + }, + }, + ), + }, + }, + { + "verb": "snapshot_rows", + "enabled": graphml_snapshot_enabled, + "args": { + "base_name": "merged_graph", + "column": "entity_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_base_text_units.py b/func-app/graphrag/index/workflows/v1/create_base_text_units.py new file mode 100644 index 0000000000..63876e5e49 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_base_text_units.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from datashaper import DEFAULT_INPUT_NAME + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_base_text_units" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the base table for text units. 
+ + ## Dependencies + None + """ + chunk_column_name = config.get("chunk_column", "chunk") + chunk_by_columns = config.get("chunk_by", []) or [] + n_tokens_column_name = config.get("n_tokens_column", "n_tokens") + return [ + { + "verb": "orderby", + "args": { + "orders": [ + # sort for reproducibility + {"column": "id", "direction": "asc"}, + ] + }, + "input": {"source": DEFAULT_INPUT_NAME}, + }, + { + "verb": "zip", + "args": { + # Pack the document ids with the text + # So when we unpack the chunks, we can restore the document id + "columns": ["id", "text"], + "to": "text_with_ids", + }, + }, + { + "verb": "aggregate_override", + "args": { + "groupby": [*chunk_by_columns] if len(chunk_by_columns) > 0 else None, + "aggregations": [ + { + "column": "text_with_ids", + "operation": "array_agg", + "to": "texts", + } + ], + }, + }, + { + "verb": "chunk", + "args": {"column": "texts", "to": "chunks", **config.get("text_chunk", {})}, + }, + { + "verb": "select", + "args": { + "columns": [*chunk_by_columns, "chunks"], + }, + }, + { + "verb": "unroll", + "args": { + "column": "chunks", + }, + }, + { + "verb": "rename", + "args": { + "columns": { + "chunks": chunk_column_name, + } + }, + }, + { + "verb": "genid", + "args": { + # Generate a unique id for each chunk + "to": "chunk_id", + "method": "md5_hash", + "hash": [chunk_column_name], + }, + }, + { + "verb": "unzip", + "args": { + "column": chunk_column_name, + "to": ["document_ids", chunk_column_name, n_tokens_column_name], + }, + }, + {"verb": "copy", "args": {"column": "chunk_id", "to": "id"}}, + { + # ELIMINATE EMPTY CHUNKS + "verb": "filter", + "args": { + "column": chunk_column_name, + "criteria": [ + { + "type": "value", + "operator": "is not empty", + } + ], + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_communities.py b/func-app/graphrag/index/workflows/v1/create_final_communities.py new file mode 100644 index 0000000000..f8949dfcec --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_communities.py @@ -0,0 +1,172 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_communities" + + +def build_steps( + _config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final communities table. 
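One detail worth calling out from the text-unit steps above: the `genid` verb's `md5_hash` method derives a stable chunk id from the chunk content, so re-runs over the same input produce the same ids. A sketch of the idea; the exact input formatting used by `genid` is an assumption here.

```python
import hashlib

def md5_hash_id(*values: str) -> str:
    # Illustrative equivalent of the genid verb's "md5_hash" method:
    # a deterministic, non-cryptographic id derived from the content.
    return hashlib.md5("".join(values).encode()).hexdigest()

chunk = "some chunk of text"
assert md5_hash_id(chunk) == md5_hash_id(chunk)  # deterministic across runs
```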
+ + ## Dependencies + * `workflow:create_base_entity_graph` + """ + return [ + { + "id": "graph_nodes", + "verb": "unpack_graph", + "args": { + "column": "clustered_graph", + "type": "nodes", + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, + { + "id": "graph_edges", + "verb": "unpack_graph", + "args": { + "column": "clustered_graph", + "type": "edges", + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, + { + "id": "source_clusters", + "verb": "join", + "args": { + "on": ["label", "source"], + }, + "input": {"source": "graph_nodes", "others": ["graph_edges"]}, + }, + { + "id": "target_clusters", + "verb": "join", + "args": { + "on": ["label", "target"], + }, + "input": {"source": "graph_nodes", "others": ["graph_edges"]}, + }, + { + "id": "concatenated_clusters", + "verb": "concat", + "input": { + "source": "source_clusters", + "others": ["target_clusters"], + }, + }, + { + "id": "combined_clusters", + "verb": "filter", + "args": { + # level_1 is the left side of the join + # level_2 is the right side of the join + "column": "level_1", + "criteria": [ + {"type": "column", "operator": "equals", "value": "level_2"} + ], + }, + "input": {"source": "concatenated_clusters"}, + }, + { + "id": "cluster_relationships", + "verb": "aggregate_override", + "args": { + "groupby": [ + "cluster", + "level_1", # level_1 is the left side of the join + ], + "aggregations": [ + { + "column": "id_2", # this is the id of the edge from the join steps above + "to": "relationship_ids", + "operation": "array_agg_distinct", + }, + { + "column": "source_id_1", + "to": "text_unit_ids", + "operation": "array_agg_distinct", + }, + ], + }, + "input": {"source": "combined_clusters"}, + }, + { + "id": "all_clusters", + "verb": "aggregate_override", + "args": { + "groupby": ["cluster", "level"], + "aggregations": [{"column": "cluster", "to": "id", "operation": "any"}], + }, + "input": {"source": "graph_nodes"}, + }, + { + "verb": "join", + "args": { + "on": ["id", "cluster"], + }, + "input": {"source": "all_clusters", "others": ["cluster_relationships"]}, + }, + { + "verb": "filter", + "args": { + # level is the left side of the join + # level_1 is the right side of the join + "column": "level", + "criteria": [ + {"type": "column", "operator": "equals", "value": "level_1"} + ], + }, + }, + *create_community_title_wf, + { + # TODO: Rodrigo says "raw_community" is temporary + "verb": "copy", + "args": { + "column": "id", + "to": "raw_community", + }, + }, + { + "verb": "select", + "args": { + "columns": [ + "id", + "title", + "level", + "raw_community", + "relationship_ids", + "text_unit_ids", + ], + }, + }, + ] + + +create_community_title_wf = [ + # Hack to string concat "Community " + id + { + "verb": "fill", + "args": { + "to": "__temp", + "value": "Community ", + }, + }, + { + "verb": "merge", + "args": { + "columns": [ + "__temp", + "id", + ], + "to": "title", + "strategy": "concat", + "preserveSource": True, + }, + }, +] diff --git a/func-app/graphrag/index/workflows/v1/create_final_community_reports.py b/func-app/graphrag/index/workflows/v1/create_final_community_reports.py new file mode 100644 index 0000000000..164c70e0dd --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_community_reports.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024 Microsoft Corporation. 
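The `fill` + `merge`(concat) pair in `create_community_title_wf` above is, as its comment says, a string-concatenation hack; its pandas equivalent is simply:

```python
import pandas as pd

df = pd.DataFrame({"id": [0, 1, 2]})

# Equivalent of the fill + merge(concat) step pair:
df["__temp"] = "Community "
df["title"] = df["__temp"] + df["id"].astype(str)

print(df["title"].tolist())  # ['Community 0', 'Community 1', 'Community 2']
```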
+# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_community_reports" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final community reports table. + + ## Dependencies + * `workflow:create_base_entity_graph` + """ + covariates_enabled = config.get("covariates_enabled", False) + create_community_reports_config = config.get("create_community_reports", {}) + base_text_embed = config.get("text_embed", {}) + community_report_full_content_embed_config = config.get( + "community_report_full_content_embed", base_text_embed + ) + community_report_summary_embed_config = config.get( + "community_report_summary_embed", base_text_embed + ) + community_report_title_embed_config = config.get( + "community_report_title_embed", base_text_embed + ) + skip_title_embedding = config.get("skip_title_embedding", False) + skip_summary_embedding = config.get("skip_summary_embedding", False) + skip_full_content_embedding = config.get("skip_full_content_embedding", False) + + return [ + # + # Subworkflow: Prepare Nodes + # + { + "id": "nodes", + "verb": "prepare_community_reports_nodes", + "input": {"source": "workflow:create_final_nodes"}, + }, + # + # Subworkflow: Prepare Edges + # + { + "id": "edges", + "verb": "prepare_community_reports_edges", + "input": {"source": "workflow:create_final_relationships"}, + }, + # + # Subworkflow: Prepare Claims Table + # + { + "id": "claims", + "enabled": covariates_enabled, + "verb": "prepare_community_reports_claims", + "input": { + "source": "workflow:create_final_covariates", + } + if covariates_enabled + else {}, + }, + # + # Subworkflow: Get Community Hierarchy + # + { + "id": "community_hierarchy", + "verb": "restore_community_hierarchy", + "input": {"source": "nodes"}, + }, + # + # Main Workflow: Create Community Reports + # + { + "id": "local_contexts", + "verb": "prepare_community_reports", + "input": { + "source": "nodes", + "nodes": "nodes", + "edges": "edges", + **({"claims": "claims"} if covariates_enabled else {}), + }, + }, + { + "verb": "create_community_reports", + "args": { + **create_community_reports_config, + }, + "input": { + "source": "local_contexts", + "community_hierarchy": "community_hierarchy", + "nodes": "nodes", + }, + }, + { + # Generate a unique ID for each community report distinct from the community ID + "verb": "window", + "args": {"to": "id", "operation": "uuid", "column": "community"}, + }, + { + "verb": "text_embed", + "enabled": not skip_full_content_embedding, + "args": { + "embedding_name": "community_report_full_content", + "column": "full_content", + "to": "full_content_embedding", + **community_report_full_content_embed_config, + }, + }, + { + "verb": "text_embed", + "enabled": not skip_summary_embedding, + "args": { + "embedding_name": "community_report_summary", + "column": "summary", + "to": "summary_embedding", + **community_report_summary_embed_config, + }, + }, + { + "verb": "text_embed", + "enabled": not skip_title_embedding, + "args": { + "embedding_name": "community_report_title", + "column": "title", + "to": "title_embedding", + **community_report_title_embed_config, + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_covariates.py b/func-app/graphrag/index/workflows/v1/create_final_covariates.py new file mode 100644 index 0000000000..d1090e5054 --- /dev/null +++ 
b/func-app/graphrag/index/workflows/v1/create_final_covariates.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from datashaper import AsyncType + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_covariates" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final covariates table. + + ## Dependencies + * `workflow:create_base_text_units` + * `workflow:create_base_extracted_entities` + """ + claim_extract_config = config.get("claim_extract", {}) + + input = {"source": "workflow:create_base_text_units"} + + return [ + { + "verb": "extract_covariates", + "args": { + "column": config.get("chunk_column", "chunk"), + "id_column": config.get("chunk_id_column", "chunk_id"), + "resolved_entities_column": "resolved_entities", + "covariate_type": "claim", + "async_mode": config.get("async_mode", AsyncType.AsyncIO), + **claim_extract_config, + }, + "input": input, + }, + { + "verb": "window", + "args": {"to": "id", "operation": "uuid", "column": "covariate_type"}, + }, + { + "verb": "genid", + "args": { + "to": "human_readable_id", + "method": "increment", + }, + }, + { + "verb": "convert", + "args": { + "column": "human_readable_id", + "type": "string", + "to": "human_readable_id", + }, + }, + { + "verb": "rename", + "args": { + "columns": { + "chunk_id": "text_unit_id", + } + }, + }, + { + "verb": "select", + "args": { + "columns": [ + "id", + "human_readable_id", + "covariate_type", + "type", + "description", + "subject_id", + "subject_type", + "object_id", + "object_type", + "status", + "start_date", + "end_date", + "source_text", + "text_unit_id", + "document_ids", + "n_tokens", + ] + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_documents.py b/func-app/graphrag/index/workflows/v1/create_final_documents.py new file mode 100644 index 0000000000..d09ce001b0 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_documents.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_documents" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final documents table. 
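This workflow also shows the per-field embedding config pattern used throughout these modules: a field-specific config falls back to the shared `text_embed` base when not set. A minimal sketch of the lookup:

```python
config = {
    "text_embed": {"strategy": {"type": "openai"}},
    # No "document_raw_content_embed" key: the base config is inherited.
}
base_text_embed = config.get("text_embed", {})
document_raw_content_embed_config = config.get(
    "document_raw_content_embed", base_text_embed
)
assert document_raw_content_embed_config == {"strategy": {"type": "openai"}}
```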
+ + ## Dependencies + * `workflow:create_base_documents` + * `workflow:create_base_document_nodes` + """ + base_text_embed = config.get("text_embed", {}) + document_raw_content_embed_config = config.get( + "document_raw_content_embed", base_text_embed + ) + skip_raw_content_embedding = config.get("skip_raw_content_embedding", False) + return [ + { + "verb": "rename", + "args": {"columns": {"text_units": "text_unit_ids"}}, + "input": {"source": "workflow:create_base_documents"}, + }, + { + "verb": "text_embed", + "enabled": not skip_raw_content_embedding, + "args": { + "column": "raw_content", + "to": "raw_content_embedding", + **document_raw_content_embed_config, + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_entities.py b/func-app/graphrag/index/workflows/v1/create_final_entities.py new file mode 100644 index 0000000000..9d8b962b77 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_entities.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_entities" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final entities table. + + ## Dependencies + * `workflow:create_base_entity_graph` + """ + base_text_embed = config.get("text_embed", {}) + entity_name_embed_config = config.get("entity_name_embed", base_text_embed) + entity_name_description_embed_config = config.get( + "entity_name_description_embed", base_text_embed + ) + skip_name_embedding = config.get("skip_name_embedding", False) + skip_description_embedding = config.get("skip_description_embedding", False) + is_using_vector_store = ( + entity_name_embed_config.get("strategy", {}).get("vector_store", None) + is not None + ) + + return [ + { + "verb": "unpack_graph", + "args": { + "column": "clustered_graph", + "type": "nodes", + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, + {"verb": "rename", "args": {"columns": {"label": "title"}}}, + { + "verb": "select", + "args": { + "columns": [ + "id", + "title", + "type", + "description", + "human_readable_id", + "graph_embedding", + "source_id", + ], + }, + }, + { + # create_base_entity_graph has multiple levels of clustering, which means there are multiple graphs with the same entities + # this dedupes the entities so that there is only one of each entity + "verb": "dedupe", + "args": {"columns": ["id"]}, + }, + {"verb": "rename", "args": {"columns": {"title": "name"}}}, + { + # ELIMINATE EMPTY NAMES + "verb": "filter", + "args": { + "column": "name", + "criteria": [ + { + "type": "value", + "operator": "is not empty", + } + ], + }, + }, + { + "verb": "text_split", + "args": {"separator": ",", "column": "source_id", "to": "text_unit_ids"}, + }, + {"verb": "drop", "args": {"columns": ["source_id"]}}, + { + "verb": "text_embed", + "enabled": not skip_name_embedding, + "args": { + "embedding_name": "entity_name", + "column": "name", + "to": "name_embedding", + **entity_name_embed_config, + }, + }, + { + "verb": "merge", + "enabled": not skip_description_embedding, + "args": { + "strategy": "concat", + "columns": ["name", "description"], + "to": "name_description", + "delimiter": ":", + "preserveSource": True, + }, + }, + { + "verb": "text_embed", + "enabled": not skip_description_embedding, + "args": { + "embedding_name": "entity_name_description", + 
"column": "name_description", + "to": "description_embedding", + **entity_name_description_embed_config, + }, + }, + { + "verb": "drop", + "enabled": not skip_description_embedding, + "args": { + "columns": ["name_description"], + }, + }, + { + # ELIMINATE EMPTY DESCRIPTION EMBEDDINGS + "verb": "filter", + "enabled": not skip_description_embedding and not is_using_vector_store, + "args": { + "column": "description_embedding", + "criteria": [ + { + "type": "value", + "operator": "is not empty", + } + ], + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_nodes.py b/func-app/graphrag/index/workflows/v1/create_final_nodes.py new file mode 100644 index 0000000000..31277e7bf0 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_nodes.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_nodes" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the base table for the document graph. + + ## Dependencies + * `workflow:create_base_entity_graph` + """ + snapshot_top_level_nodes = config.get("snapshot_top_level_nodes", False) + layout_graph_enabled = config.get("layout_graph_enabled", True) + _compute_top_level_node_positions = [ + { + "verb": "unpack_graph", + "args": {"column": "positioned_graph", "type": "nodes"}, + "input": {"source": "laid_out_entity_graph"}, + }, + { + "verb": "filter", + "args": { + "column": "level", + "criteria": [ + { + "type": "value", + "operator": "equals", + "value": config.get("level_for_node_positions", 0), + } + ], + }, + }, + { + "verb": "select", + "args": {"columns": ["id", "x", "y"]}, + }, + { + "verb": "snapshot", + "enabled": snapshot_top_level_nodes, + "args": { + "name": "top_level_nodes", + "formats": ["json"], + }, + }, + { + "id": "_compute_top_level_node_positions", + "verb": "rename", + "args": { + "columns": { + "id": "top_level_node_id", + } + }, + }, + { + "verb": "convert", + "args": { + "column": "top_level_node_id", + "to": "top_level_node_id", + "type": "string", + }, + }, + ] + layout_graph_config = config.get( + "layout_graph", + { + "strategy": { + "type": "umap" if layout_graph_enabled else "zero", + }, + }, + ) + return [ + { + "id": "laid_out_entity_graph", + "verb": "layout_graph", + "args": { + "embeddings_column": "embeddings", + "graph_column": "clustered_graph", + "to": "node_positions", + "graph_to": "positioned_graph", + **layout_graph_config, + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, + { + "verb": "unpack_graph", + "args": {"column": "positioned_graph", "type": "nodes"}, + }, + { + "id": "nodes_without_positions", + "verb": "drop", + "args": {"columns": ["x", "y"]}, + }, + *_compute_top_level_node_positions, + { + "verb": "join", + "args": { + "on": ["id", "top_level_node_id"], + }, + "input": { + "source": "nodes_without_positions", + "others": ["_compute_top_level_node_positions"], + }, + }, + { + "verb": "rename", + "args": {"columns": {"label": "title", "cluster": "community"}}, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_relationships.py b/func-app/graphrag/index/workflows/v1/create_final_relationships.py new file mode 100644 index 0000000000..a58c2a45b4 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_relationships.py @@ -0,0 +1,94 @@ +# 
Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_relationships" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final relationships table. + + ## Dependencies + * `workflow:create_base_entity_graph` + """ + base_text_embed = config.get("text_embed", {}) + relationship_description_embed_config = config.get( + "relationship_description_embed", base_text_embed + ) + skip_description_embedding = config.get("skip_description_embedding", False) + + return [ + { + "verb": "unpack_graph", + "args": { + "column": "clustered_graph", + "type": "edges", + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, + { + "verb": "rename", + "args": {"columns": {"source_id": "text_unit_ids"}}, + }, + { + "verb": "filter", + "args": { + "column": "level", + "criteria": [{"type": "value", "operator": "equals", "value": 0}], + }, + }, + { + "verb": "text_embed", + "enabled": not skip_description_embedding, + "args": { + "embedding_name": "relationship_description", + "column": "description", + "to": "description_embedding", + **relationship_description_embed_config, + }, + }, + { + "id": "pruned_edges", + "verb": "drop", + "args": {"columns": ["level"]}, + }, + { + "id": "filtered_nodes", + "verb": "filter", + "args": { + "column": "level", + "criteria": [{"type": "value", "operator": "equals", "value": 0}], + }, + "input": "workflow:create_final_nodes", + }, + { + "verb": "compute_edge_combined_degree", + "args": {"to": "rank"}, + "input": { + "source": "pruned_edges", + "nodes": "filtered_nodes", + }, + }, + { + "verb": "convert", + "args": { + "column": "human_readable_id", + "type": "string", + "to": "human_readable_id", + }, + }, + { + "verb": "convert", + "args": { + "column": "text_unit_ids", + "type": "array", + "to": "text_unit_ids", + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_final_text_units.py b/func-app/graphrag/index/workflows/v1/create_final_text_units.py new file mode 100644 index 0000000000..56dd0a73d6 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_final_text_units.py @@ -0,0 +1,161 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing build_steps method definition.""" + +from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep + +workflow_name = "create_final_text_units" + + +def build_steps( + config: PipelineWorkflowConfig, +) -> list[PipelineWorkflowStep]: + """ + Create the final text-units table. 
+ + ## Dependencies + * `workflow:create_base_text_units` + * `workflow:create_final_entities` + * `workflow:create_final_communities` + """ + base_text_embed = config.get("text_embed", {}) + text_unit_text_embed_config = config.get("text_unit_text_embed", base_text_embed) + covariates_enabled = config.get("covariates_enabled", False) + skip_text_unit_embedding = config.get("skip_text_unit_embedding", False) + is_using_vector_store = ( + text_unit_text_embed_config.get("strategy", {}).get("vector_store", None) + is not None + ) + + return [ + { + "verb": "select", + "args": {"columns": ["id", "chunk", "document_ids", "n_tokens"]}, + "input": {"source": "workflow:create_base_text_units"}, + }, + { + "id": "pre_entity_join", + "verb": "rename", + "args": { + "columns": { + "chunk": "text", + }, + }, + }, + # Expand the TextUnits with EntityIDs + { + "id": "pre_relationship_join", + "verb": "join", + "args": { + "on": ["id", "id"], + "strategy": "left outer", + }, + "input": { + "source": "pre_entity_join", + "others": ["workflow:join_text_units_to_entity_ids"], + }, + }, + # Expand the TextUnits with RelationshipIDs + { + "id": "pre_covariate_join", + "verb": "join", + "args": { + "on": ["id", "id"], + "strategy": "left outer", + }, + "input": { + "source": "pre_relationship_join", + "others": ["workflow:join_text_units_to_relationship_ids"], + }, + }, + # Expand the TextUnits with CovariateIDs + { + "enabled": covariates_enabled, + "verb": "join", + "args": { + "on": ["id", "id"], + "strategy": "left outer", + }, + "input": { + "source": "pre_covariate_join", + "others": ["workflow:join_text_units_to_covariate_ids"], + }, + }, + # Mash the entities and relationships into arrays + { + "verb": "aggregate_override", + "args": { + "groupby": ["id"], # from the join above + "aggregations": [ + { + "column": "text", + "operation": "any", + "to": "text", + }, + { + "column": "n_tokens", + "operation": "any", + "to": "n_tokens", + }, + { + "column": "document_ids", + "operation": "any", + "to": "document_ids", + }, + { + "column": "entity_ids", + "operation": "any", + "to": "entity_ids", + }, + { + "column": "relationship_ids", + "operation": "any", + "to": "relationship_ids", + }, + *( + [] + if not covariates_enabled + else [ + { + "column": "covariate_ids", + "operation": "any", + "to": "covariate_ids", + } + ] + ), + ], + }, + }, + # Text-Embed after final aggregations + { + "id": "embedded_text_units", + "verb": "text_embed", + "enabled": not skip_text_unit_embedding, + "args": { + "column": config.get("column", "text"), + "to": config.get("to", "text_embedding"), + **text_unit_text_embed_config, + }, + }, + { + "verb": "select", + "args": { + # Final select to get output in the correct shape + "columns": [ + "id", + "text", + *( + [] + if (skip_text_unit_embedding or is_using_vector_store) + else ["text_embedding"] + ), + "n_tokens", + "document_ids", + "entity_ids", + "relationship_ids", + *([] if not covariates_enabled else ["covariate_ids"]), + ], + }, + }, + ] diff --git a/func-app/graphrag/index/workflows/v1/create_summarized_entities.py b/func-app/graphrag/index/workflows/v1/create_summarized_entities.py new file mode 100644 index 0000000000..8f8d7f0042 --- /dev/null +++ b/func-app/graphrag/index/workflows/v1/create_summarized_entities.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024 Microsoft Corporation. 
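The text-units workflow above expands each text unit through left-outer joins (so units without entities, relationships, or covariates survive) and then collapses back to one row per id with `any`-style aggregations. A rough pandas analogue, using `first` as a stand-in for `any`:

```python
import pandas as pd

text_units = pd.DataFrame({"id": ["t1", "t2"], "text": ["a", "b"]})
entity_join = pd.DataFrame({"id": ["t1"], "entity_ids": [["e1", "e2"]]})

# Left-outer join keeps every text unit, even those with no entities...
merged = text_units.merge(entity_join, on="id", how="left")

# ...then the aggregation collapses back to one row per id.
rolled = merged.groupby("id", as_index=False).first()
print(rolled["entity_ids"].tolist())  # [['e1', 'e2'], nan]
```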
+# Licensed under the MIT License
+
+"""A module containing build_steps method definition."""
+
+from datashaper import AsyncType
+
+from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
+
+workflow_name = "create_summarized_entities"
+
+
+def build_steps(
+    config: PipelineWorkflowConfig,
+) -> list[PipelineWorkflowStep]:
+    """
+    Summarize the entity and relationship descriptions in the extracted entity graph.
+
+    ## Dependencies
+    * `workflow:create_base_extracted_entities`
+    """
+    summarize_descriptions_config = config.get("summarize_descriptions", {})
+    graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
+
+    return [
+        {
+            "verb": "summarize_descriptions",
+            "args": {
+                **summarize_descriptions_config,
+                "column": "entity_graph",
+                "to": "entity_graph",
+                "async_mode": summarize_descriptions_config.get(
+                    "async_mode", AsyncType.AsyncIO
+                ),
+            },
+            "input": {"source": "workflow:create_base_extracted_entities"},
+        },
+        {
+            "verb": "snapshot_rows",
+            "enabled": graphml_snapshot_enabled,
+            "args": {
+                "base_name": "summarized_graph",
+                "column": "entity_graph",
+                "formats": [{"format": "text", "extension": "graphml"}],
+            },
+        },
+    ]
diff --git a/func-app/graphrag/index/workflows/v1/join_text_units_to_covariate_ids.py b/func-app/graphrag/index/workflows/v1/join_text_units_to_covariate_ids.py
new file mode 100644
index 0000000000..be6bddf1e4
--- /dev/null
+++ b/func-app/graphrag/index/workflows/v1/join_text_units_to_covariate_ids.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing build_steps method definition."""
+
+from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
+
+workflow_name = "join_text_units_to_covariate_ids"
+
+
+def build_steps(
+    _config: PipelineWorkflowConfig,
+) -> list[PipelineWorkflowStep]:
+    """
+    Create a join table from text unit ids to covariate ids.
+
+    ## Dependencies
+    * `workflow:create_final_covariates`
+    """
+    return [
+        {
+            "verb": "select",
+            "args": {"columns": ["id", "text_unit_id"]},
+            "input": {"source": "workflow:create_final_covariates"},
+        },
+        {
+            "verb": "aggregate_override",
+            "args": {
+                "groupby": ["text_unit_id"],
+                "aggregations": [
+                    {
+                        "column": "id",
+                        "operation": "array_agg_distinct",
+                        "to": "covariate_ids",
+                    },
+                    {
+                        "column": "text_unit_id",
+                        "operation": "any",
+                        "to": "id",
+                    },
+                ],
+            },
+        },
+    ]
diff --git a/func-app/graphrag/index/workflows/v1/join_text_units_to_entity_ids.py b/func-app/graphrag/index/workflows/v1/join_text_units_to_entity_ids.py
new file mode 100644
index 0000000000..6337502d1a
--- /dev/null
+++ b/func-app/graphrag/index/workflows/v1/join_text_units_to_entity_ids.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing build_steps method definition."""
+
+from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
+
+workflow_name = "join_text_units_to_entity_ids"
+
+
+def build_steps(
+    _config: PipelineWorkflowConfig,
+) -> list[PipelineWorkflowStep]:
+    """
+    Create a join table from text unit ids to entity ids.
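+    Unrolls each entity's text_unit_ids and regroups the distinct entity ids
+    under each text unit id.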
+
+    ## Dependencies
+    * `workflow:create_final_entities`
+    """
+    return [
+        {
+            "verb": "select",
+            "args": {"columns": ["id", "text_unit_ids"]},
+            "input": {"source": "workflow:create_final_entities"},
+        },
+        {
+            "verb": "unroll",
+            "args": {
+                "column": "text_unit_ids",
+            },
+        },
+        {
+            "verb": "aggregate_override",
+            "args": {
+                "groupby": ["text_unit_ids"],
+                "aggregations": [
+                    {
+                        "column": "id",
+                        "operation": "array_agg_distinct",
+                        "to": "entity_ids",
+                    },
+                    {
+                        "column": "text_unit_ids",
+                        "operation": "any",
+                        "to": "id",
+                    },
+                ],
+            },
+        },
+    ]
diff --git a/func-app/graphrag/index/workflows/v1/join_text_units_to_relationship_ids.py b/func-app/graphrag/index/workflows/v1/join_text_units_to_relationship_ids.py
new file mode 100644
index 0000000000..fe6d6463be
--- /dev/null
+++ b/func-app/graphrag/index/workflows/v1/join_text_units_to_relationship_ids.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing build_steps method definition."""
+
+from graphrag.index.config import PipelineWorkflowConfig, PipelineWorkflowStep
+
+workflow_name = "join_text_units_to_relationship_ids"
+
+
+def build_steps(
+    _config: PipelineWorkflowConfig,
+) -> list[PipelineWorkflowStep]:
+    """
+    Create a join table from text unit ids to relationship ids.
+
+    ## Dependencies
+    * `workflow:create_final_relationships`
+    """
+    return [
+        {
+            "verb": "select",
+            "args": {"columns": ["id", "text_unit_ids"]},
+            "input": {"source": "workflow:create_final_relationships"},
+        },
+        {
+            "verb": "unroll",
+            "args": {
+                "column": "text_unit_ids",
+            },
+        },
+        {
+            "verb": "aggregate_override",
+            "args": {
+                "groupby": ["text_unit_ids"],
+                "aggregations": [
+                    {
+                        "column": "id",
+                        "operation": "array_agg_distinct",
+                        "to": "relationship_ids",
+                    },
+                    {
+                        "column": "text_unit_ids",
+                        "operation": "any",
+                        "to": "id",
+                    },
+                ],
+            },
+        },
+        {
+            "id": "text_unit_id_to_relationship_ids",
+            "verb": "select",
+            "args": {"columns": ["id", "relationship_ids"]},
+        },
+    ]
diff --git a/func-app/graphrag/llm/__init__.py b/func-app/graphrag/llm/__init__.py
new file mode 100644
index 0000000000..609be951b2
--- /dev/null
+++ b/func-app/graphrag/llm/__init__.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Datashaper OpenAI Utilities package."""
+
+from .base import BaseLLM, CachingLLM, RateLimitingLLM
+from .errors import RetriesExhaustedError
+from .limiting import (
+    CompositeLLMLimiter,
+    LLMLimiter,
+    NoopLLMLimiter,
+    TpmRpmLLMLimiter,
+    create_tpm_rpm_limiters,
+)
+from .mock import MockChatLLM, MockCompletionLLM
+from .openai import (
+    OpenAIChatLLM,
+    OpenAIClientTypes,
+    OpenAICompletionLLM,
+    OpenAIConfiguration,
+    OpenAIEmbeddingsLLM,
+    create_openai_chat_llm,
+    create_openai_client,
+    create_openai_completion_llm,
+    create_openai_embedding_llm,
+)
+from .types import (
+    LLM,
+    CompletionInput,
+    CompletionLLM,
+    CompletionOutput,
+    EmbeddingInput,
+    EmbeddingLLM,
+    EmbeddingOutput,
+    ErrorHandlerFn,
+    IsResponseValidFn,
+    LLMCache,
+    LLMConfig,
+    LLMInput,
+    LLMInvocationFn,
+    LLMInvocationResult,
+    LLMOutput,
+    OnCacheActionFn,
+)
+
+__all__ = [
+    # LLM Types
+    "LLM",
+    "BaseLLM",
+    "CachingLLM",
+    "CompletionInput",
+    "CompletionLLM",
+    "CompletionOutput",
+    "CompositeLLMLimiter",
+    "EmbeddingInput",
+    "EmbeddingLLM",
+    "EmbeddingOutput",
+    # Callbacks
+    "ErrorHandlerFn",
+    "IsResponseValidFn",
+    # Cache
+    "LLMCache",
+    "LLMConfig",
+    # LLM I/O Types
+    "LLMInput",
+    "LLMInvocationFn",
+    "LLMInvocationResult",
+    "LLMLimiter",
+    "LLMOutput",
+    "MockChatLLM",
+    # Mock
+    "MockCompletionLLM",
+    "NoopLLMLimiter",
+    "OnCacheActionFn",
+    "OpenAIChatLLM",
+    "OpenAIClientTypes",
+    "OpenAICompletionLLM",
+    # OpenAI
+    "OpenAIConfiguration",
+    "OpenAIEmbeddingsLLM",
+    "RateLimitingLLM",
+    # Errors
+    "RetriesExhaustedError",
+    "TpmRpmLLMLimiter",
+    "create_openai_chat_llm",
+    "create_openai_client",
+    "create_openai_completion_llm",
+    "create_openai_embedding_llm",
+    # Limiters
+    "create_tpm_rpm_limiters",
+]
diff --git a/func-app/graphrag/llm/base/__init__.py b/func-app/graphrag/llm/base/__init__.py
new file mode 100644
index 0000000000..dd5ebf9050
--- /dev/null
+++ b/func-app/graphrag/llm/base/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Base LLM Implementations."""
+
+from .base_llm import BaseLLM
+from .caching_llm import CachingLLM
+from .rate_limiting_llm import RateLimitingLLM
+
+__all__ = ["BaseLLM", "CachingLLM", "RateLimitingLLM"]
diff --git a/func-app/graphrag/llm/base/_create_cache_key.py b/func-app/graphrag/llm/base/_create_cache_key.py
new file mode 100644
index 0000000000..b5fdd839bc
--- /dev/null
+++ b/func-app/graphrag/llm/base/_create_cache_key.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Cache key generation utils."""
+
+import hashlib
+import json
+
+
+def _llm_string(params: dict) -> str:
+    # The new cache version does not include "n" in the params dictionary;
+    # normalizing it here avoids creating a new cache key for the same prompt.
+    if "max_tokens" in params and "n" not in params:
+        params["n"] = None
+    return str(sorted((k, v) for k, v in params.items()))
+
+
+def _hash(_input: str) -> str:
+    """Use a deterministic hashing approach."""
+    return hashlib.md5(_input.encode()).hexdigest()  # noqa: S324
+
+
+def create_hash_key(
+    operation: str, prompt: str, parameters: dict, history: list[dict] | None
+) -> str:
+    """Compute cache key from prompt and associated model and settings.
+
+    Args:
+        operation (str): The operation name, used to namespace the key.
+        prompt (str): The prompt run through the language model.
+        parameters (dict): The language model version and settings.
+        history (list[dict] | None): The chat history, if any.
+
+    Returns
+    -------
+        str: The cache key.
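+
+    Example (illustrative; the digest depends on the exact inputs):
+        create_hash_key("chat", "hello", {"max_tokens": 16}, None)
+        # -> "chat-<32-character md5 hex digest>"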
+ """ + llm_string = _llm_string(parameters) + history_string = _hash(json.dumps(history)) if history else None + hash_string = ( + _hash(prompt + llm_string + history_string) + if history_string + else _hash(prompt + llm_string) + ) + return f"{operation}-{hash_string}" diff --git a/func-app/graphrag/llm/base/base_llm.py b/func-app/graphrag/llm/base/base_llm.py new file mode 100644 index 0000000000..66f1919cd4 --- /dev/null +++ b/func-app/graphrag/llm/base/base_llm.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Base LLM class definition.""" + +import traceback +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +from typing_extensions import Unpack + +from graphrag.llm.types import ( + LLM, + ErrorHandlerFn, + LLMInput, + LLMOutput, +) + +TIn = TypeVar("TIn") +TOut = TypeVar("TOut") + + +class BaseLLM(ABC, LLM[TIn, TOut], Generic[TIn, TOut]): + """LLM Implementation class definition.""" + + _on_error: ErrorHandlerFn | None + + def on_error(self, on_error: ErrorHandlerFn | None) -> None: + """Set the error handler function.""" + self._on_error = on_error + + @abstractmethod + async def _execute_llm( + self, + input: TIn, + **kwargs: Unpack[LLMInput], + ) -> TOut | None: + pass + + async def __call__( + self, + input: TIn, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[TOut]: + """Invoke the LLM.""" + is_json = kwargs.get("json") or False + if is_json: + return await self._invoke_json(input, **kwargs) + return await self._invoke(input, **kwargs) + + async def _invoke(self, input: TIn, **kwargs: Unpack[LLMInput]) -> LLMOutput[TOut]: + try: + output = await self._execute_llm(input, **kwargs) + return LLMOutput(output=output) + except Exception as e: + stack_trace = traceback.format_exc() + if self._on_error: + self._on_error(e, stack_trace, {"input": input}) + raise + + async def _invoke_json( + self, input: TIn, **kwargs: Unpack[LLMInput] + ) -> LLMOutput[TOut]: + msg = "JSON output not supported by this LLM" + raise NotImplementedError(msg) diff --git a/func-app/graphrag/llm/base/caching_llm.py b/func-app/graphrag/llm/base/caching_llm.py new file mode 100644 index 0000000000..c039de5122 --- /dev/null +++ b/func-app/graphrag/llm/base/caching_llm.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A class to interact with the cache.""" + +import json +from typing import Generic, TypeVar + +from typing_extensions import Unpack + +from graphrag.llm.types import LLM, LLMCache, LLMInput, LLMOutput, OnCacheActionFn + +from ._create_cache_key import create_hash_key + +# If there's a breaking change in what we cache, we should increment this version number to invalidate existing caches +_cache_strategy_version = 2 + +TIn = TypeVar("TIn") +TOut = TypeVar("TOut") + + +def _noop_cache_fn(_k: str, _v: str | None): + pass + + +class CachingLLM(LLM[TIn, TOut], Generic[TIn, TOut]): + """A class to interact with the cache.""" + + _cache: LLMCache + _delegate: LLM[TIn, TOut] + _operation: str + _llm_parameters: dict + _on_cache_hit: OnCacheActionFn + _on_cache_miss: OnCacheActionFn + + def __init__( + self, + delegate: LLM[TIn, TOut], + llm_parameters: dict, + operation: str, + cache: LLMCache, + ): + self._delegate = delegate + self._llm_parameters = llm_parameters + self._cache = cache + self._operation = operation + self._on_cache_hit = _noop_cache_fn + self._on_cache_miss = _noop_cache_fn + + def set_delegate(self, delegate: LLM[TIn, TOut]) -> None: + """Set the delegate LLM. (for testing).""" + self._delegate = delegate + + def on_cache_hit(self, fn: OnCacheActionFn | None) -> None: + """Set the function to call when a cache hit occurs.""" + self._on_cache_hit = fn or _noop_cache_fn + + def on_cache_miss(self, fn: OnCacheActionFn | None) -> None: + """Set the function to call when a cache miss occurs.""" + self._on_cache_miss = fn or _noop_cache_fn + + def _cache_key( + self, input: TIn, name: str | None, args: dict, history: list[dict] | None + ) -> str: + json_input = json.dumps(input) + tag = ( + f"{name}-{self._operation}-v{_cache_strategy_version}" + if name is not None + else self._operation + ) + return create_hash_key(tag, json_input, args, history) + + async def __call__( + self, + input: TIn, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[TOut]: + """Execute the LLM.""" + # Check for an Existing cache item + name = kwargs.get("name") + history_in = kwargs.get("history") or None + llm_args = {**self._llm_parameters, **(kwargs.get("model_parameters") or {})} + cache_key = self._cache_key(input, name, llm_args, history_in) + cached_result = await self._cache.get(cache_key) + + if cached_result: + self._on_cache_hit(cache_key, name) + return LLMOutput( + output=cached_result, + ) + + # Report the Cache Miss + self._on_cache_miss(cache_key, name) + + # Compute the new result + result = await self._delegate(input, **kwargs) + + # Cache the new result + if result.output is not None: + await self._cache.set( + cache_key, + result.output, + { + "input": input, + "parameters": llm_args, + "history": history_in, + }, + ) + return result diff --git a/func-app/graphrag/llm/base/rate_limiting_llm.py b/func-app/graphrag/llm/base/rate_limiting_llm.py new file mode 100644 index 0000000000..5e2082475f --- /dev/null +++ b/func-app/graphrag/llm/base/rate_limiting_llm.py @@ -0,0 +1,208 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Rate limiting LLM implementation.""" + +import asyncio +import logging +from collections.abc import Callable +from typing import Any, Generic, TypeVar + +from tenacity import ( + AsyncRetrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) +from typing_extensions import Unpack + +from graphrag.llm.errors import RetriesExhaustedError +from graphrag.llm.limiting import LLMLimiter +from graphrag.llm.types import ( + LLM, + LLMConfig, + LLMInput, + LLMInvocationFn, + LLMInvocationResult, + LLMOutput, +) + +TIn = TypeVar("TIn") +TOut = TypeVar("TOut") +TRateLimitError = TypeVar("TRateLimitError", bound=BaseException) + +_CANNOT_MEASURE_INPUT_TOKENS_MSG = "cannot measure input tokens" +_CANNOT_MEASURE_OUTPUT_TOKENS_MSG = "cannot measure output tokens" + +log = logging.getLogger(__name__) + + +class RateLimitingLLM(LLM[TIn, TOut], Generic[TIn, TOut]): + """A class to interact with the cache.""" + + _delegate: LLM[TIn, TOut] + _rate_limiter: LLMLimiter | None + _semaphore: asyncio.Semaphore | None + _count_tokens: Callable[[str], int] + _config: LLMConfig + _operation: str + _retryable_errors: list[type[Exception]] + _rate_limit_errors: list[type[Exception]] + _on_invoke: LLMInvocationFn + _extract_sleep_recommendation: Callable[[Any], float] + + def __init__( + self, + delegate: LLM[TIn, TOut], + config: LLMConfig, + operation: str, + retryable_errors: list[type[Exception]], + rate_limit_errors: list[type[Exception]], + rate_limiter: LLMLimiter | None = None, + semaphore: asyncio.Semaphore | None = None, + count_tokens: Callable[[str], int] | None = None, + get_sleep_time: Callable[[BaseException], float] | None = None, + ): + self._delegate = delegate + self._rate_limiter = rate_limiter + self._semaphore = semaphore + self._config = config + self._operation = operation + self._retryable_errors = retryable_errors + self._rate_limit_errors = rate_limit_errors + self._count_tokens = count_tokens or (lambda _s: -1) + self._extract_sleep_recommendation = get_sleep_time or (lambda _e: 0.0) + self._on_invoke = lambda _v: None + + def on_invoke(self, fn: LLMInvocationFn | None) -> None: + """Set the on_invoke function.""" + self._on_invoke = fn or (lambda _v: None) + + def count_request_tokens(self, input: TIn) -> int: + """Count the request tokens on an input request.""" + if isinstance(input, str): + return self._count_tokens(input) + if isinstance(input, list): + result = 0 + for item in input: + if isinstance(item, str): + result += self._count_tokens(item) + elif isinstance(item, dict): + result += self._count_tokens(item.get("content", "")) + else: + raise TypeError(_CANNOT_MEASURE_INPUT_TOKENS_MSG) + return result + raise TypeError(_CANNOT_MEASURE_INPUT_TOKENS_MSG) + + def count_response_tokens(self, output: TOut | None) -> int: + """Count the request tokens on an output response.""" + if output is None: + return 0 + if isinstance(output, str): + return self._count_tokens(output) + if isinstance(output, list) and all(isinstance(x, str) for x in output): + return sum(self._count_tokens(item) for item in output) + if isinstance(output, list): + # Embedding response, don't count it + return 0 + raise TypeError(_CANNOT_MEASURE_OUTPUT_TOKENS_MSG) + + async def __call__( + self, + input: TIn, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[TOut]: + """Execute the LLM with semaphore & rate limiting.""" + name = kwargs.get("name", "Process") + attempt_number = 0 + call_times: list[float] = [] + input_tokens = 
self.count_request_tokens(input) + max_retries = self._config.max_retries or 10 + max_retry_wait = self._config.max_retry_wait or 10 + follow_recommendation = self._config.sleep_on_rate_limit_recommendation + retryer = AsyncRetrying( + stop=stop_after_attempt(max_retries), + wait=wait_exponential_jitter(max=max_retry_wait), + reraise=True, + retry=retry_if_exception_type(tuple(self._retryable_errors)), + ) + + async def sleep_for(time: float | None) -> None: + log.warning( + "%s failed to invoke LLM %s/%s attempts. Cause: rate limit exceeded, will retry. Recommended sleep for %d seconds. Follow recommendation? %s", + name, + attempt_number, + max_retries, + time, + follow_recommendation, + ) + if follow_recommendation and time: + await asyncio.sleep(time) + raise + + async def do_attempt() -> LLMOutput[TOut]: + nonlocal call_times + call_start = asyncio.get_event_loop().time() + try: + return await self._delegate(input, **kwargs) + except BaseException as e: + if isinstance(e, tuple(self._rate_limit_errors)): + sleep_time = self._extract_sleep_recommendation(e) + await sleep_for(sleep_time) + raise + finally: + call_end = asyncio.get_event_loop().time() + call_times.append(call_end - call_start) + + async def execute_with_retry() -> tuple[LLMOutput[TOut], float]: + nonlocal attempt_number + async for attempt in retryer: + with attempt: + if self._rate_limiter and input_tokens > 0: + await self._rate_limiter.acquire(input_tokens) + start = asyncio.get_event_loop().time() + attempt_number += 1 + return await do_attempt(), start + + log.error("Retries exhausted for %s", name) + raise RetriesExhaustedError(name, max_retries) + + result: LLMOutput[TOut] + start = 0.0 + + if self._semaphore is None: + result, start = await execute_with_retry() + else: + async with self._semaphore: + result, start = await execute_with_retry() + + end = asyncio.get_event_loop().time() + output_tokens = self.count_response_tokens(result.output) + if self._rate_limiter and output_tokens > 0: + await self._rate_limiter.acquire(output_tokens) + + invocation_result = LLMInvocationResult( + result=result, + name=name, + num_retries=attempt_number - 1, + total_time=end - start, + call_times=call_times, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + self._handle_invoke_result(invocation_result) + return result + + def _handle_invoke_result( + self, result: LLMInvocationResult[LLMOutput[TOut]] + ) -> None: + log.info( + 'perf - llm.%s "%s" with %s retries took %s. input_tokens=%d, output_tokens=%d', + self._operation, + result.name, + result.num_retries, + result.total_time, + result.input_tokens, + result.output_tokens, + ) + self._on_invoke(result) diff --git a/func-app/graphrag/llm/errors.py b/func-app/graphrag/llm/errors.py new file mode 100644 index 0000000000..01136359de --- /dev/null +++ b/func-app/graphrag/llm/errors.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Error definitions for the OpenAI DataShaper package.""" + + +class RetriesExhaustedError(RuntimeError): + """Retries exhausted error.""" + + def __init__(self, name: str, num_retries: int) -> None: + """Init method definition.""" + super().__init__(f"Operation '{name}' failed - {num_retries} retries exhausted") diff --git a/func-app/graphrag/llm/limiting/__init__.py b/func-app/graphrag/llm/limiting/__init__.py new file mode 100644 index 0000000000..4f7933d1a8 --- /dev/null +++ b/func-app/graphrag/llm/limiting/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM limiters module.""" + +from .composite_limiter import CompositeLLMLimiter +from .create_limiters import create_tpm_rpm_limiters +from .llm_limiter import LLMLimiter +from .noop_llm_limiter import NoopLLMLimiter +from .tpm_rpm_limiter import TpmRpmLLMLimiter + +__all__ = [ + "CompositeLLMLimiter", + "LLMLimiter", + "NoopLLMLimiter", + "TpmRpmLLMLimiter", + "create_tpm_rpm_limiters", +] diff --git a/func-app/graphrag/llm/limiting/composite_limiter.py b/func-app/graphrag/llm/limiting/composite_limiter.py new file mode 100644 index 0000000000..7bcf9195b2 --- /dev/null +++ b/func-app/graphrag/llm/limiting/composite_limiter.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing Composite Limiter class definition.""" + +from .llm_limiter import LLMLimiter + + +class CompositeLLMLimiter(LLMLimiter): + """Composite Limiter class definition.""" + + _limiters: list[LLMLimiter] + + def __init__(self, limiters: list[LLMLimiter]): + """Init method definition.""" + self._limiters = limiters + + @property + def needs_token_count(self) -> bool: + """Whether this limiter needs the token count to be passed in.""" + return any(limiter.needs_token_count for limiter in self._limiters) + + async def acquire(self, num_tokens: int = 1) -> None: + """Call method definition.""" + for limiter in self._limiters: + await limiter.acquire(num_tokens) diff --git a/func-app/graphrag/llm/limiting/create_limiters.py b/func-app/graphrag/llm/limiting/create_limiters.py new file mode 100644 index 0000000000..92df11c1a6 --- /dev/null +++ b/func-app/graphrag/llm/limiting/create_limiters.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Create limiters for OpenAI API requests.""" + +import logging + +from aiolimiter import AsyncLimiter + +from graphrag.llm.types import LLMConfig + +from .llm_limiter import LLMLimiter +from .tpm_rpm_limiter import TpmRpmLLMLimiter + +log = logging.getLogger(__name__) + +"""The global TPM limiters.""" + + +def create_tpm_rpm_limiters( + configuration: LLMConfig, +) -> LLMLimiter: + """Get the limiters for a given model name.""" + tpm = configuration.tokens_per_minute + rpm = configuration.requests_per_minute + return TpmRpmLLMLimiter( + None if tpm == 0 else AsyncLimiter(tpm or 50_000), + None if rpm == 0 else AsyncLimiter(rpm or 10_000), + ) diff --git a/func-app/graphrag/llm/limiting/llm_limiter.py b/func-app/graphrag/llm/limiting/llm_limiter.py new file mode 100644 index 0000000000..1264a84be5 --- /dev/null +++ b/func-app/graphrag/llm/limiting/llm_limiter.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Limiting types."""
+
+from abc import ABC, abstractmethod
+
+
+class LLMLimiter(ABC):
+    """LLM Limiter Interface."""
+
+    @property
+    @abstractmethod
+    def needs_token_count(self) -> bool:
+        """Whether this limiter needs the token count to be passed in."""
+
+    @abstractmethod
+    async def acquire(self, num_tokens: int = 1) -> None:
+        """Acquire a pass through the limiter."""
diff --git a/func-app/graphrag/llm/limiting/noop_llm_limiter.py b/func-app/graphrag/llm/limiting/noop_llm_limiter.py
new file mode 100644
index 0000000000..5147055255
--- /dev/null
+++ b/func-app/graphrag/llm/limiting/noop_llm_limiter.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""No-op LLM limiter module."""
+
+from .llm_limiter import LLMLimiter
+
+
+class NoopLLMLimiter(LLMLimiter):
+    """A limiter that applies no throttling."""
+
+    @property
+    def needs_token_count(self) -> bool:
+        """Whether this limiter needs the token count to be passed in."""
+        return False
+
+    async def acquire(self, num_tokens: int = 1) -> None:
+        """Do nothing."""
+        # do nothing
diff --git a/func-app/graphrag/llm/limiting/tpm_rpm_limiter.py b/func-app/graphrag/llm/limiting/tpm_rpm_limiter.py
new file mode 100644
index 0000000000..cb6d84e377
--- /dev/null
+++ b/func-app/graphrag/llm/limiting/tpm_rpm_limiter.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""TPM RPM Limiter module."""
+
+from aiolimiter import AsyncLimiter
+
+from .llm_limiter import LLMLimiter
+
+
+class TpmRpmLLMLimiter(LLMLimiter):
+    """TPM RPM Limiter class definition."""
+
+    _tpm_limiter: AsyncLimiter | None
+    _rpm_limiter: AsyncLimiter | None
+
+    def __init__(
+        self, tpm_limiter: AsyncLimiter | None, rpm_limiter: AsyncLimiter | None
+    ):
+        """Init method definition."""
+        self._tpm_limiter = tpm_limiter
+        self._rpm_limiter = rpm_limiter
+
+    @property
+    def needs_token_count(self) -> bool:
+        """Whether this limiter needs the token count to be passed in."""
+        return self._tpm_limiter is not None
+
+    async def acquire(self, num_tokens: int = 1) -> None:
+        """Call method definition."""
+        if self._tpm_limiter is not None:
+            await self._tpm_limiter.acquire(num_tokens)
+        if self._rpm_limiter is not None:
+            await self._rpm_limiter.acquire()
diff --git a/func-app/graphrag/llm/mock/__init__.py b/func-app/graphrag/llm/mock/__init__.py
new file mode 100644
index 0000000000..cd1f000dd1
--- /dev/null
+++ b/func-app/graphrag/llm/mock/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Mock LLM Implementations."""
+
+from .mock_chat_llm import MockChatLLM
+from .mock_completion_llm import MockCompletionLLM
+
+__all__ = [
+    "MockChatLLM",
+    "MockCompletionLLM",
+]
diff --git a/func-app/graphrag/llm/mock/mock_chat_llm.py b/func-app/graphrag/llm/mock/mock_chat_llm.py
new file mode 100644
index 0000000000..b8a6650b31
--- /dev/null
+++ b/func-app/graphrag/llm/mock/mock_chat_llm.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""A mock ChatLLM that returns the given responses.""" + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + CompletionInput, + CompletionOutput, + LLMInput, + LLMOutput, +) + + +class MockChatLLM( + BaseLLM[ + CompletionInput, + CompletionOutput, + ] +): + """A mock LLM that returns the given responses.""" + + responses: list[str] + i: int = 0 + + def __init__(self, responses: list[str]): + self.i = 0 + self.responses = responses + + def _create_output( + self, + output: CompletionOutput | None, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[CompletionOutput]: + history = kwargs.get("history") or [] + return LLMOutput[CompletionOutput]( + output=output, history=[*history, {"content": output}] + ) + + async def _execute_llm( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> CompletionOutput: + if self.i >= len(self.responses): + msg = f"No more responses, requested {self.i} but only have {len(self.responses)}" + raise ValueError(msg) + response = self.responses[self.i] + self.i += 1 + return response diff --git a/func-app/graphrag/llm/mock/mock_completion_llm.py b/func-app/graphrag/llm/mock/mock_completion_llm.py new file mode 100644 index 0000000000..8cb8e95083 --- /dev/null +++ b/func-app/graphrag/llm/mock/mock_completion_llm.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Static Response method definition.""" + +import logging + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + CompletionInput, + CompletionOutput, + LLMInput, +) + +log = logging.getLogger(__name__) + + +class MockCompletionLLM( + BaseLLM[ + CompletionInput, + CompletionOutput, + ] +): + """Mock Completion LLM for testing purposes.""" + + def __init__(self, responses: list[str]): + self.responses = responses + self._on_error = None + + async def _execute_llm( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> CompletionOutput: + return self.responses[0] diff --git a/func-app/graphrag/llm/openai/__init__.py b/func-app/graphrag/llm/openai/__init__.py new file mode 100644 index 0000000000..9478e146d2 --- /dev/null +++ b/func-app/graphrag/llm/openai/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""OpenAI LLM implementations.""" + +from .create_openai_client import create_openai_client +from .factories import ( + create_openai_chat_llm, + create_openai_completion_llm, + create_openai_embedding_llm, +) +from .openai_chat_llm import OpenAIChatLLM +from .openai_completion_llm import OpenAICompletionLLM +from .openai_configuration import OpenAIConfiguration +from .openai_embeddings_llm import OpenAIEmbeddingsLLM +from .types import OpenAIClientTypes + +__all__ = [ + "OpenAIChatLLM", + "OpenAIClientTypes", + "OpenAICompletionLLM", + "OpenAIConfiguration", + "OpenAIEmbeddingsLLM", + "create_openai_chat_llm", + "create_openai_client", + "create_openai_completion_llm", + "create_openai_embedding_llm", +] diff --git a/func-app/graphrag/llm/openai/_prompts.py b/func-app/graphrag/llm/openai/_prompts.py new file mode 100644 index 0000000000..37d9f0fc70 --- /dev/null +++ b/func-app/graphrag/llm/openai/_prompts.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Utility prompts for low-level LLM invocations.""" + +JSON_CHECK_PROMPT = """ +You are going to be given a malformed JSON string that threw an error during json.loads. +It probably contains unnecessary escape sequences, or it is missing a comma or colon somewhere. +Your task is to fix this string and return a well-formed JSON string containing a single object. +Eliminate any unnecessary escape sequences. +Only return valid JSON, parseable with json.loads, without commentary. + +# Examples +----------- +Text: {{ \\"title\\": \\"abc\\", \\"summary\\": \\"def\\" }} +Output: {{"title": "abc", "summary": "def"}} +----------- +Text: {{"title": "abc", "summary": "def" +Output: {{"title": "abc", "summary": "def"}} +----------- +Text: {{"title': "abc", 'summary": "def" +Output: {{"title": "abc", "summary": "def"}} +----------- +Text: "{{"title": "abc", "summary": "def"}}" +Output: {{"title": "abc", "summary": "def"}} +----------- +Text: [{{"title": "abc", "summary": "def"}}] +Output: [{{"title": "abc", "summary": "def"}}] +----------- +Text: [{{"title": "abc", "summary": "def"}}, {{ \\"title\\": \\"abc\\", \\"summary\\": \\"def\\" }}] +Output: [{{"title": "abc", "summary": "def"}}, {{"title": "abc", "summary": "def"}}] +----------- +Text: ```json\n[{{"title": "abc", "summary": "def"}}, {{ \\"title\\": \\"abc\\", \\"summary\\": \\"def\\" }}]``` +Output: [{{"title": "abc", "summary": "def"}}, {{"title": "abc", "summary": "def"}}] + + +# Real Data +Text: {input_text} +Output:""" diff --git a/func-app/graphrag/llm/openai/create_openai_client.py b/func-app/graphrag/llm/openai/create_openai_client.py new file mode 100644 index 0000000000..cd149323c6 --- /dev/null +++ b/func-app/graphrag/llm/openai/create_openai_client.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Create OpenAI client instance.""" + +import logging +from functools import cache + +from azure.identity import ManagedIdentityCredential, get_bearer_token_provider +from openai import AsyncAzureOpenAI, AsyncOpenAI + +from .openai_configuration import OpenAIConfiguration +from .types import OpenAIClientTypes + +log = logging.getLogger(__name__) + +API_BASE_REQUIRED_FOR_AZURE = "api_base is required for Azure OpenAI client" + + +@cache +def create_openai_client( + configuration: OpenAIConfiguration, azure: bool +) -> OpenAIClientTypes: + """Create a new OpenAI client instance.""" + if azure: + api_base = configuration.api_base + if api_base is None: + raise ValueError(API_BASE_REQUIRED_FOR_AZURE) + + log.info( + "Creating Azure OpenAI client api_base=%s, deployment_name=%s", + api_base, + configuration.deployment_name, + ) + if configuration.cognitive_services_endpoint is None: + cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default" + else: + cognitive_services_endpoint = configuration.cognitive_services_endpoint + + return AsyncAzureOpenAI( + api_key=configuration.api_key if configuration.api_key else None, + azure_ad_token_provider=get_bearer_token_provider( + ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint + ) + if not configuration.api_key + else None, + organization=configuration.organization, + # Azure-Specifics + api_version=configuration.api_version, + azure_endpoint=api_base, + azure_deployment=configuration.deployment_name, + # Timeout/Retry Configuration - Use Tenacity for Retries, so disable them here + timeout=configuration.request_timeout or 180.0, + max_retries=0, + ) + + log.info("Creating OpenAI client base_url=%s", configuration.api_base) + return AsyncOpenAI( + api_key=configuration.api_key, + base_url=configuration.api_base, + organization=configuration.organization, + # Timeout/Retry Configuration - Use Tenacity for Retries, so disable them here + timeout=configuration.request_timeout or 180.0, + max_retries=0, + ) diff --git a/func-app/graphrag/llm/openai/factories.py b/func-app/graphrag/llm/openai/factories.py new file mode 100644 index 0000000000..e595e2e55b --- /dev/null +++ b/func-app/graphrag/llm/openai/factories.py @@ -0,0 +1,140 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Factory functions for creating OpenAI LLMs.""" + +import asyncio + +from graphrag.llm.base import CachingLLM, RateLimitingLLM +from graphrag.llm.limiting import LLMLimiter +from graphrag.llm.types import ( + LLM, + CompletionLLM, + EmbeddingLLM, + ErrorHandlerFn, + LLMCache, + LLMInvocationFn, + OnCacheActionFn, +) + +from .json_parsing_llm import JsonParsingLLM +from .openai_chat_llm import OpenAIChatLLM +from .openai_completion_llm import OpenAICompletionLLM +from .openai_configuration import OpenAIConfiguration +from .openai_embeddings_llm import OpenAIEmbeddingsLLM +from .openai_history_tracking_llm import OpenAIHistoryTrackingLLM +from .openai_token_replacing_llm import OpenAITokenReplacingLLM +from .types import OpenAIClientTypes +from .utils import ( + RATE_LIMIT_ERRORS, + RETRYABLE_ERRORS, + get_completion_cache_args, + get_sleep_time_from_error, + get_token_counter, +) + + +def create_openai_chat_llm( + client: OpenAIClientTypes, + config: OpenAIConfiguration, + cache: LLMCache | None = None, + limiter: LLMLimiter | None = None, + semaphore: asyncio.Semaphore | None = None, + on_invoke: LLMInvocationFn | None = None, + on_error: ErrorHandlerFn | None = None, + on_cache_hit: OnCacheActionFn | None = None, + on_cache_miss: OnCacheActionFn | None = None, +) -> CompletionLLM: + """Create an OpenAI chat LLM.""" + operation = "chat" + result = OpenAIChatLLM(client, config) + result.on_error(on_error) + if limiter is not None or semaphore is not None: + result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) + if cache is not None: + result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) + result = OpenAIHistoryTrackingLLM(result) + result = OpenAITokenReplacingLLM(result) + return JsonParsingLLM(result) + + +def create_openai_completion_llm( + client: OpenAIClientTypes, + config: OpenAIConfiguration, + cache: LLMCache | None = None, + limiter: LLMLimiter | None = None, + semaphore: asyncio.Semaphore | None = None, + on_invoke: LLMInvocationFn | None = None, + on_error: ErrorHandlerFn | None = None, + on_cache_hit: OnCacheActionFn | None = None, + on_cache_miss: OnCacheActionFn | None = None, +) -> CompletionLLM: + """Create an OpenAI completion LLM.""" + operation = "completion" + result = OpenAICompletionLLM(client, config) + result.on_error(on_error) + if limiter is not None or semaphore is not None: + result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) + if cache is not None: + result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) + return OpenAITokenReplacingLLM(result) + + +def create_openai_embedding_llm( + client: OpenAIClientTypes, + config: OpenAIConfiguration, + cache: LLMCache | None = None, + limiter: LLMLimiter | None = None, + semaphore: asyncio.Semaphore | None = None, + on_invoke: LLMInvocationFn | None = None, + on_error: ErrorHandlerFn | None = None, + on_cache_hit: OnCacheActionFn | None = None, + on_cache_miss: OnCacheActionFn | None = None, +) -> EmbeddingLLM: + """Create an OpenAI embeddings LLM.""" + operation = "embedding" + result = OpenAIEmbeddingsLLM(client, config) + result.on_error(on_error) + if limiter is not None or semaphore is not None: + result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) + if cache is not None: + result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) + return result + + +def _rate_limited( + delegate: LLM, + config: OpenAIConfiguration, + 
operation: str,
+    limiter: LLMLimiter | None,
+    semaphore: asyncio.Semaphore | None,
+    on_invoke: LLMInvocationFn | None,
+):
+    result = RateLimitingLLM(
+        delegate,
+        config,
+        operation,
+        RETRYABLE_ERRORS,
+        RATE_LIMIT_ERRORS,
+        limiter,
+        semaphore,
+        get_token_counter(config),
+        get_sleep_time_from_error,
+    )
+    result.on_invoke(on_invoke)
+    return result
+
+
+def _cached(
+    delegate: LLM,
+    config: OpenAIConfiguration,
+    operation: str,
+    cache: LLMCache,
+    on_cache_hit: OnCacheActionFn | None,
+    on_cache_miss: OnCacheActionFn | None,
+):
+    cache_args = get_completion_cache_args(config)
+    result = CachingLLM(delegate, cache_args, operation, cache)
+    result.on_cache_hit(on_cache_hit)
+    result.on_cache_miss(on_cache_miss)
+    return result
diff --git a/func-app/graphrag/llm/openai/json_parsing_llm.py b/func-app/graphrag/llm/openai/json_parsing_llm.py
new file mode 100644
index 0000000000..009c1da42e
--- /dev/null
+++ b/func-app/graphrag/llm/openai/json_parsing_llm.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""An LLM that unpacks cached JSON responses."""
+
+from typing_extensions import Unpack
+
+from graphrag.llm.types import (
+    LLM,
+    CompletionInput,
+    CompletionLLM,
+    CompletionOutput,
+    LLMInput,
+    LLMOutput,
+)
+
+from .utils import try_parse_json_object
+
+
+class JsonParsingLLM(LLM[CompletionInput, CompletionOutput]):
+    """An LLM wrapper that parses JSON output from its delegate."""
+
+    _delegate: CompletionLLM
+
+    def __init__(self, delegate: CompletionLLM):
+        self._delegate = delegate
+
+    async def __call__(
+        self,
+        input: CompletionInput,
+        **kwargs: Unpack[LLMInput],
+    ) -> LLMOutput[CompletionOutput]:
+        """Call the LLM with the input and kwargs."""
+        result = await self._delegate(input, **kwargs)
+        if kwargs.get("json") and result.json is None and result.output is not None:
+            _, parsed_json = try_parse_json_object(result.output)
+            result.json = parsed_json
+        return result
diff --git a/func-app/graphrag/llm/openai/openai_chat_llm.py b/func-app/graphrag/llm/openai/openai_chat_llm.py
new file mode 100644
index 0000000000..d08f8af80c
--- /dev/null
+++ b/func-app/graphrag/llm/openai/openai_chat_llm.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""The Chat-based language model.""" + +import logging + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + CompletionInput, + CompletionOutput, + LLMInput, + LLMOutput, +) + +from ._prompts import JSON_CHECK_PROMPT +from .openai_configuration import OpenAIConfiguration +from .types import OpenAIClientTypes +from .utils import ( + get_completion_llm_args, + try_parse_json_object, +) + +log = logging.getLogger(__name__) + +_MAX_GENERATION_RETRIES = 3 +FAILED_TO_CREATE_JSON_ERROR = "Failed to generate valid JSON output" + + +class OpenAIChatLLM(BaseLLM[CompletionInput, CompletionOutput]): + """A Chat-based LLM.""" + + _client: OpenAIClientTypes + _configuration: OpenAIConfiguration + + def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration): + self.client = client + self.configuration = configuration + + async def _execute_llm( + self, input: CompletionInput, **kwargs: Unpack[LLMInput] + ) -> CompletionOutput | None: + args = get_completion_llm_args( + kwargs.get("model_parameters"), self.configuration + ) + history = kwargs.get("history") or [] + messages = [ + *history, + {"role": "user", "content": input}, + ] + completion = await self.client.chat.completions.create( + messages=messages, **args + ) + return completion.choices[0].message.content + + async def _invoke_json( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[CompletionOutput]: + """Generate JSON output.""" + name = kwargs.get("name") or "unknown" + is_response_valid = kwargs.get("is_response_valid") or (lambda _x: True) + + async def generate( + attempt: int | None = None, + ) -> LLMOutput[CompletionOutput]: + call_name = name if attempt is None else f"{name}@{attempt}" + return ( + await self._native_json(input, **{**kwargs, "name": call_name}) + if self.configuration.model_supports_json + else await self._manual_json(input, **{**kwargs, "name": call_name}) + ) + + def is_valid(x: dict | None) -> bool: + return x is not None and is_response_valid(x) + + result = await generate() + retry = 0 + while not is_valid(result.json) and retry < _MAX_GENERATION_RETRIES: + result = await generate(retry) + retry += 1 + + if is_valid(result.json): + return result + raise RuntimeError(FAILED_TO_CREATE_JSON_ERROR) + + async def _native_json( + self, input: CompletionInput, **kwargs: Unpack[LLMInput] + ) -> LLMOutput[CompletionOutput]: + """Generate JSON output using a model's native JSON-output support.""" + result = await self._invoke( + input, + **{ + **kwargs, + "model_parameters": { + **(kwargs.get("model_parameters") or {}), + "response_format": {"type": "json_object"}, + }, + }, + ) + + output, json_output = try_parse_json_object(result.output or "") + + return LLMOutput[CompletionOutput]( + output=output, + json=json_output, + history=result.history, + ) + + async def _manual_json( + self, input: CompletionInput, **kwargs: Unpack[LLMInput] + ) -> LLMOutput[CompletionOutput]: + # Otherwise, clean up the output and try to parse it as json + result = await self._invoke(input, **kwargs) + history = result.history or [] + output, json_output = try_parse_json_object(result.output or "") + if json_output: + return LLMOutput[CompletionOutput]( + output=result.output, json=json_output, history=history + ) + # if not return correct formatted json, retry + log.warning("error parsing llm json, retrying") + + # If cleaned up json is unparsable, use the LLM to reformat it (may throw) + 
result = await self._try_clean_json_with_llm(output, **kwargs)
+        output, json_output = try_parse_json_object(result.output or "")
+
+        return LLMOutput[CompletionOutput](
+            output=output,
+            json=json_output,
+            history=history,
+        )
+
+    async def _try_clean_json_with_llm(
+        self, output: str, **kwargs: Unpack[LLMInput]
+    ) -> LLMOutput[CompletionOutput]:
+        name = kwargs.get("name") or "unknown"
+        return await self._invoke(
+            JSON_CHECK_PROMPT,
+            **{
+                **kwargs,
+                "variables": {"input_text": output},
+                "name": f"fix_json@{name}",
+            },
+        )
diff --git a/func-app/graphrag/llm/openai/openai_completion_llm.py b/func-app/graphrag/llm/openai/openai_completion_llm.py
new file mode 100644
index 0000000000..bdbac6c131
--- /dev/null
+++ b/func-app/graphrag/llm/openai/openai_completion_llm.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A text-completion based LLM."""
+
+import logging
+
+from typing_extensions import Unpack
+
+from graphrag.llm.base import BaseLLM
+from graphrag.llm.types import (
+    CompletionInput,
+    CompletionOutput,
+    LLMInput,
+)
+
+from .openai_configuration import OpenAIConfiguration
+from .types import OpenAIClientTypes
+from .utils import get_completion_llm_args
+
+log = logging.getLogger(__name__)
+
+
+class OpenAICompletionLLM(BaseLLM[CompletionInput, CompletionOutput]):
+    """A text-completion based LLM."""
+
+    _client: OpenAIClientTypes
+    _configuration: OpenAIConfiguration
+
+    def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration):
+        self.client = client
+        self.configuration = configuration
+
+    async def _execute_llm(
+        self,
+        input: CompletionInput,
+        **kwargs: Unpack[LLMInput],
+    ) -> CompletionOutput | None:
+        args = get_completion_llm_args(
+            kwargs.get("model_parameters"), self.configuration
+        )
+        # The client is async, so the call must be awaited.
+        completion = await self.client.completions.create(prompt=input, **args)
+        return completion.choices[0].text
diff --git a/func-app/graphrag/llm/openai/openai_configuration.py b/func-app/graphrag/llm/openai/openai_configuration.py
new file mode 100644
index 0000000000..1bcd5694d6
--- /dev/null
+++ b/func-app/graphrag/llm/openai/openai_configuration.py
@@ -0,0 +1,288 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""OpenAI Configuration class definition.""" + +import json +from collections.abc import Hashable +from typing import Any, cast + +from graphrag.llm.types import LLMConfig + + +def _non_blank(value: str | None) -> str | None: + if value is None: + return None + stripped = value.strip() + return None if stripped == "" else value + + +class OpenAIConfiguration(Hashable, LLMConfig): + """OpenAI Configuration class definition.""" + + # Core Configuration + _api_key: str + _model: str + + _api_base: str | None + _api_version: str | None + _cognitive_services_endpoint: str | None + _deployment_name: str | None + _organization: str | None + _proxy: str | None + + # Operation Configuration + _n: int | None + _temperature: float | None + _frequency_penalty: float | None + _presence_penalty: float | None + _top_p: float | None + _max_tokens: int | None + _response_format: str | None + _logit_bias: dict[str, float] | None + _stop: list[str] | None + + # Retry Logic + _max_retries: int | None + _max_retry_wait: float | None + _request_timeout: float | None + + # The raw configuration object + _raw_config: dict + + # Feature Flags + _model_supports_json: bool | None + + # Custom Configuration + _tokens_per_minute: int | None + _requests_per_minute: int | None + _concurrent_requests: int | None + _encoding_model: str | None + _sleep_on_rate_limit_recommendation: bool | None + + def __init__( + self, + config: dict, + ): + """Init method definition.""" + + def lookup_required(key: str) -> str: + return cast(str, config.get(key)) + + def lookup_str(key: str) -> str | None: + return cast(str | None, config.get(key)) + + def lookup_int(key: str) -> int | None: + result = config.get(key) + if result is None: + return None + return int(cast(int, result)) + + def lookup_float(key: str) -> float | None: + result = config.get(key) + if result is None: + return None + return float(cast(float, result)) + + def lookup_dict(key: str) -> dict | None: + return cast(dict | None, config.get(key)) + + def lookup_list(key: str) -> list | None: + return cast(list | None, config.get(key)) + + def lookup_bool(key: str) -> bool | None: + value = config.get(key) + if isinstance(value, str): + return value.upper() == "TRUE" + if isinstance(value, int): + return value > 0 + return cast(bool | None, config.get(key)) + + self._api_key = lookup_required("api_key") + self._model = lookup_required("model") + self._deployment_name = lookup_str("deployment_name") + self._api_base = lookup_str("api_base") + self._api_version = lookup_str("api_version") + self._cognitive_services_endpoint = lookup_str("cognitive_services_endpoint") + self._organization = lookup_str("organization") + self._proxy = lookup_str("proxy") + self._n = lookup_int("n") + self._temperature = lookup_float("temperature") + self._frequency_penalty = lookup_float("frequency_penalty") + self._presence_penalty = lookup_float("presence_penalty") + self._top_p = lookup_float("top_p") + self._max_tokens = lookup_int("max_tokens") + self._response_format = lookup_str("response_format") + self._logit_bias = lookup_dict("logit_bias") + self._stop = lookup_list("stop") + self._max_retries = lookup_int("max_retries") + self._request_timeout = lookup_float("request_timeout") + self._model_supports_json = lookup_bool("model_supports_json") + self._tokens_per_minute = lookup_int("tokens_per_minute") + self._requests_per_minute = lookup_int("requests_per_minute") + self._concurrent_requests = lookup_int("concurrent_requests") + 
self._encoding_model = lookup_str("encoding_model") + self._max_retry_wait = lookup_float("max_retry_wait") + self._sleep_on_rate_limit_recommendation = lookup_bool( + "sleep_on_rate_limit_recommendation" + ) + self._raw_config = config + + @property + def api_key(self) -> str: + """API key property definition.""" + return self._api_key + + @property + def model(self) -> str: + """Model property definition.""" + return self._model + + @property + def deployment_name(self) -> str | None: + """Deployment name property definition.""" + return _non_blank(self._deployment_name) + + @property + def api_base(self) -> str | None: + """API base property definition.""" + result = _non_blank(self._api_base) + # Remove trailing slash + return result[:-1] if result and result.endswith("/") else result + + @property + def api_version(self) -> str | None: + """API version property definition.""" + return _non_blank(self._api_version) + + @property + def cognitive_services_endpoint(self) -> str | None: + """API version property definition.""" + return _non_blank(self._cognitive_services_endpoint) + + @property + def organization(self) -> str | None: + """Organization property definition.""" + return _non_blank(self._organization) + + @property + def proxy(self) -> str | None: + """Proxy property definition.""" + return _non_blank(self._proxy) + + @property + def n(self) -> int | None: + """N property definition.""" + return self._n + + @property + def temperature(self) -> float | None: + """Temperature property definition.""" + return self._temperature + + @property + def frequency_penalty(self) -> float | None: + """Frequency penalty property definition.""" + return self._frequency_penalty + + @property + def presence_penalty(self) -> float | None: + """Presence penalty property definition.""" + return self._presence_penalty + + @property + def top_p(self) -> float | None: + """Top p property definition.""" + return self._top_p + + @property + def max_tokens(self) -> int | None: + """Max tokens property definition.""" + return self._max_tokens + + @property + def response_format(self) -> str | None: + """Response format property definition.""" + return _non_blank(self._response_format) + + @property + def logit_bias(self) -> dict[str, float] | None: + """Logit bias property definition.""" + return self._logit_bias + + @property + def stop(self) -> list[str] | None: + """Stop property definition.""" + return self._stop + + @property + def max_retries(self) -> int | None: + """Max retries property definition.""" + return self._max_retries + + @property + def max_retry_wait(self) -> float | None: + """Max retry wait property definition.""" + return self._max_retry_wait + + @property + def request_timeout(self) -> float | None: + """Request timeout property definition.""" + return self._request_timeout + + @property + def model_supports_json(self) -> bool | None: + """Model supports json property definition.""" + return self._model_supports_json + + @property + def tokens_per_minute(self) -> int | None: + """Tokens per minute property definition.""" + return self._tokens_per_minute + + @property + def requests_per_minute(self) -> int | None: + """Requests per minute property definition.""" + return self._requests_per_minute + + @property + def concurrent_requests(self) -> int | None: + """Concurrent requests property definition.""" + return self._concurrent_requests + + @property + def encoding_model(self) -> str | None: + """Encoding model property definition.""" + return _non_blank(self._encoding_model) + 
+ @property + def sleep_on_rate_limit_recommendation(self) -> bool | None: + """Whether to sleep for seconds when recommended by 429 errors (azure-specific).""" + return self._sleep_on_rate_limit_recommendation + + @property + def raw_config(self) -> dict: + """Raw config method definition.""" + return self._raw_config + + def lookup(self, name: str, default_value: Any = None) -> Any: + """Lookup method definition.""" + return self._raw_config.get(name, default_value) + + def __str__(self) -> str: + """Str method definition.""" + return json.dumps(self.raw_config, indent=4) + + def __repr__(self) -> str: + """Repr method definition.""" + return f"OpenAIConfiguration({self._raw_config})" + + def __eq__(self, other: object) -> bool: + """Eq method definition.""" + if not isinstance(other, OpenAIConfiguration): + return False + return self._raw_config == other._raw_config + + def __hash__(self) -> int: + """Hash method definition.""" + return hash(tuple(sorted(self._raw_config.items()))) diff --git a/func-app/graphrag/llm/openai/openai_embeddings_llm.py b/func-app/graphrag/llm/openai/openai_embeddings_llm.py new file mode 100644 index 0000000000..558afe8437 --- /dev/null +++ b/func-app/graphrag/llm/openai/openai_embeddings_llm.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The EmbeddingsLLM class.""" + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + EmbeddingInput, + EmbeddingOutput, + LLMInput, +) + +from .openai_configuration import OpenAIConfiguration +from .types import OpenAIClientTypes + + +class OpenAIEmbeddingsLLM(BaseLLM[EmbeddingInput, EmbeddingOutput]): + """A text-embedding generator LLM.""" + + _client: OpenAIClientTypes + _configuration: OpenAIConfiguration + + def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration): + self.client = client + self.configuration = configuration + + async def _execute_llm( + self, input: EmbeddingInput, **kwargs: Unpack[LLMInput] + ) -> EmbeddingOutput | None: + args = { + "model": self.configuration.model, + **(kwargs.get("model_parameters") or {}), + } + embedding = await self.client.embeddings.create( + input=input, + **args, + ) + return [d.embedding for d in embedding.data] diff --git a/func-app/graphrag/llm/openai/openai_history_tracking_llm.py b/func-app/graphrag/llm/openai/openai_history_tracking_llm.py new file mode 100644 index 0000000000..ab903c2d2a --- /dev/null +++ b/func-app/graphrag/llm/openai/openai_history_tracking_llm.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""A history-tracking LLM wrapper."""
+
+from typing_extensions import Unpack
+
+from graphrag.llm.types import (
+    LLM,
+    CompletionInput,
+    CompletionLLM,
+    CompletionOutput,
+    LLMInput,
+    LLMOutput,
+)
+
+
+class OpenAIHistoryTrackingLLM(LLM[CompletionInput, CompletionOutput]):
+    """An OpenAI History-Tracking LLM."""
+
+    _delegate: CompletionLLM
+
+    def __init__(self, delegate: CompletionLLM):
+        self._delegate = delegate
+
+    async def __call__(
+        self,
+        input: CompletionInput,
+        **kwargs: Unpack[LLMInput],
+    ) -> LLMOutput[CompletionOutput]:
+        """Call the LLM."""
+        history = kwargs.get("history") or []
+        output = await self._delegate(input, **kwargs)
+        return LLMOutput(
+            output=output.output,
+            json=output.json,
+            history=[
+                *history,
+                {"role": "user", "content": input},
+                {"role": "assistant", "content": output.output},
+            ],
+        )
diff --git a/func-app/graphrag/llm/openai/openai_token_replacing_llm.py b/func-app/graphrag/llm/openai/openai_token_replacing_llm.py
new file mode 100644
index 0000000000..7385b84059
--- /dev/null
+++ b/func-app/graphrag/llm/openai/openai_token_replacing_llm.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A variable-replacing LLM wrapper."""
+
+from typing_extensions import Unpack
+
+from graphrag.llm.types import (
+    LLM,
+    CompletionInput,
+    CompletionLLM,
+    CompletionOutput,
+    LLMInput,
+    LLMOutput,
+)
+
+from .utils import perform_variable_replacements
+
+
+class OpenAITokenReplacingLLM(LLM[CompletionInput, CompletionOutput]):
+    """An LLM wrapper that replaces prompt variables before delegating."""
+
+    _delegate: CompletionLLM
+
+    def __init__(self, delegate: CompletionLLM):
+        self._delegate = delegate
+
+    async def __call__(
+        self,
+        input: CompletionInput,
+        **kwargs: Unpack[LLMInput],
+    ) -> LLMOutput[CompletionOutput]:
+        """Call the LLM with the input and kwargs."""
+        variables = kwargs.get("variables")
+        history = kwargs.get("history") or []
+        input = perform_variable_replacements(input, history, variables)
+        return await self._delegate(input, **kwargs)
diff --git a/func-app/graphrag/llm/openai/types.py b/func-app/graphrag/llm/openai/types.py
new file mode 100644
index 0000000000..4aacf18c1c
--- /dev/null
+++ b/func-app/graphrag/llm/openai/types.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Client type definitions for OpenAI-based LLMs."""
+
+from openai import (
+    AsyncAzureOpenAI,
+    AsyncOpenAI,
+)
+
+OpenAIClientTypes = AsyncOpenAI | AsyncAzureOpenAI
diff --git a/func-app/graphrag/llm/openai/utils.py b/func-app/graphrag/llm/openai/utils.py
new file mode 100644
index 0000000000..5d683951b1
--- /dev/null
+++ b/func-app/graphrag/llm/openai/utils.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2024 Microsoft Corporation.
diff --git a/func-app/graphrag/llm/openai/utils.py b/func-app/graphrag/llm/openai/utils.py
new file mode 100644
index 0000000000..5d683951b1
--- /dev/null
+++ b/func-app/graphrag/llm/openai/utils.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Utility functions for the OpenAI API."""
+
+import json
+import logging
+import re
+from collections.abc import Callable
+from typing import Any
+
+import tiktoken
+from json_repair import repair_json
+from openai import (
+    APIConnectionError,
+    InternalServerError,
+    RateLimitError,
+)
+
+from .openai_configuration import OpenAIConfiguration
+
+DEFAULT_ENCODING = "cl100k_base"
+
+_encoders: dict[str, tiktoken.Encoding] = {}
+
+RETRYABLE_ERRORS: list[type[Exception]] = [
+    RateLimitError,
+    APIConnectionError,
+    InternalServerError,
+]
+RATE_LIMIT_ERRORS: list[type[Exception]] = [RateLimitError]
+
+log = logging.getLogger(__name__)
+
+
+def get_token_counter(config: OpenAIConfiguration) -> Callable[[str], int]:
+    """Get a function that counts the number of tokens in a string."""
+    model = config.encoding_model or DEFAULT_ENCODING
+    enc = _encoders.get(model)
+    if enc is None:
+        enc = tiktoken.get_encoding(model)
+        _encoders[model] = enc
+
+    return lambda s: len(enc.encode(s))
+
+
+def perform_variable_replacements(
+    input: str, history: list[dict], variables: dict | None
+) -> str:
+    """Perform variable replacements on the input string and in a chat log."""
+    result = input
+
+    def replace_all(input: str) -> str:
+        result = input
+        if variables:
+            for entry in variables:
+                result = result.replace(f"{{{entry}}}", variables[entry])
+        return result
+
+    result = replace_all(result)
+    for i in range(len(history)):
+        entry = history[i]
+        if entry.get("role") == "system":
+            history[i]["content"] = replace_all(entry.get("content") or "")
+
+    return result
+
+
+def get_completion_cache_args(configuration: OpenAIConfiguration) -> dict:
+    """Get the cache arguments for a completion LLM."""
+    return {
+        "model": configuration.model,
+        "temperature": configuration.temperature,
+        "frequency_penalty": configuration.frequency_penalty,
+        "presence_penalty": configuration.presence_penalty,
+        "top_p": configuration.top_p,
+        "max_tokens": configuration.max_tokens,
+        "n": configuration.n,
+    }
+
+
+def get_completion_llm_args(
+    parameters: dict | None, configuration: OpenAIConfiguration
+) -> dict:
+    """Get the arguments for a completion LLM."""
+    return {
+        **get_completion_cache_args(configuration),
+        **(parameters or {}),
+    }
+
+
+def try_parse_json_object(input: str) -> tuple[str, dict]:
+    """Parse a JSON object from an LLM response, cleaning and repairing the string when needed."""
+    # Sometimes the LLM returns a JSON string with some extra description; this function cleans it up.
+
+    result = None
+    try:
+        # Try parse first
+        result = json.loads(input)
+    except json.JSONDecodeError:
+        log.info("Warning: Error decoding faulty json, attempting repair")
+
+    if result:
+        return input, result
+
+    _pattern = r"\{(.*)\}"
+    _match = re.search(_pattern, input)
+    input = "{" + _match.group(1) + "}" if _match else input
+
+    # Clean up json string: collapse escaped newlines before stripping lone
+    # backslashes, then remove literal newlines and carriage returns.
+    input = (
+        input.replace("{{", "{")
+        .replace("}}", "}")
+        .replace('"[{', "[{")
+        .replace('}]"', "}]")
+        .replace("\\n", " ")
+        .replace("\\", " ")
+        .replace("\n", " ")
+        .replace("\r", "")
+        .strip()
+    )
+
+    # Remove JSON Markdown Frame
+    if input.startswith("```json"):
+        input = input[len("```json") :]
+    if input.endswith("```"):
+        input = input[: len(input) - len("```")]
+
+    try:
+        result = json.loads(input)
+    except json.JSONDecodeError:
+        # Fixup potentially malformed json string using json_repair.
+        input = str(repair_json(json_str=input, return_objects=False))
+
+        # Generate JSON-string output using best-attempt prompting & parsing techniques.
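+        # Retry the parse now that json_repair has rewritten the string; if it
+        # still fails, log the error and fall back to an empty result below.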
+ try: + result = json.loads(input) + except json.JSONDecodeError: + log.exception("error loading json, json=%s", input) + return input, {} + else: + if not isinstance(result, dict): + log.exception("not expected dict type. type=%s:", type(result)) + return input, {} + return input, result + else: + return input, result + + +def get_sleep_time_from_error(e: Any) -> float: + """Extract the sleep time value from a RateLimitError. This is usually only available in Azure.""" + sleep_time = 0.0 + if isinstance(e, RateLimitError) and _please_retry_after in str(e): + # could be second or seconds + sleep_time = int(str(e).split(_please_retry_after)[1].split(" second")[0]) + + return sleep_time + + +_please_retry_after = "Please retry after " diff --git a/func-app/graphrag/llm/types/__init__.py b/func-app/graphrag/llm/types/__init__.py new file mode 100644 index 0000000000..c8277661d5 --- /dev/null +++ b/func-app/graphrag/llm/types/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Typings.""" + +from .llm import LLM +from .llm_cache import LLMCache +from .llm_callbacks import ( + ErrorHandlerFn, + IsResponseValidFn, + LLMInvocationFn, + OnCacheActionFn, +) +from .llm_config import LLMConfig +from .llm_invocation_result import LLMInvocationResult +from .llm_io import ( + LLMInput, + LLMOutput, +) +from .llm_types import ( + CompletionInput, + CompletionLLM, + CompletionOutput, + EmbeddingInput, + EmbeddingLLM, + EmbeddingOutput, +) + +__all__ = [ + "LLM", + "CompletionInput", + "CompletionLLM", + "CompletionOutput", + "EmbeddingInput", + "EmbeddingLLM", + "EmbeddingOutput", + "ErrorHandlerFn", + "IsResponseValidFn", + "LLMCache", + "LLMConfig", + "LLMInput", + "LLMInvocationFn", + "LLMInvocationResult", + "LLMOutput", + "OnCacheActionFn", +] diff --git a/func-app/graphrag/llm/types/llm.py b/func-app/graphrag/llm/types/llm.py new file mode 100644 index 0000000000..fd8407e50e --- /dev/null +++ b/func-app/graphrag/llm/types/llm.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Types.""" + +from typing import Generic, Protocol, TypeVar + +from typing_extensions import Unpack + +from .llm_io import ( + LLMInput, + LLMOutput, +) + +TIn = TypeVar("TIn", contravariant=True) +TOut = TypeVar("TOut") + + +class LLM(Protocol, Generic[TIn, TOut]): + """LLM Protocol definition.""" + + async def __call__( + self, + input: TIn, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[TOut]: + """Invoke the LLM, treating the LLM as a function.""" + ... diff --git a/func-app/graphrag/llm/types/llm_cache.py b/func-app/graphrag/llm/types/llm_cache.py new file mode 100644 index 0000000000..952b8d346d --- /dev/null +++ b/func-app/graphrag/llm/types/llm_cache.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Typing definitions for the OpenAI DataShaper package.""" + +from typing import Any, Protocol + + +class LLMCache(Protocol): + """LLM Cache interface.""" + + async def has(self, key: str) -> bool: + """Check if the cache has a value.""" + ... + + async def get(self, key: str) -> Any | None: + """Retrieve a value from the cache.""" + ... + + async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None: + """Write a value into the cache.""" + ... 
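+
+
+# Illustrative sketch (not part of this patch's public surface): a minimal
+# dict-backed implementation satisfying the LLMCache protocol above, handy
+# for tests or examples.
+class InMemoryLLMCache:
+    """An in-memory LLMCache for tests and examples."""

+    def __init__(self) -> None:
+        self._store: dict[str, Any] = {}
+
+    async def has(self, key: str) -> bool:
+        """Check if the cache has a value."""
+        return key in self._store
+
+    async def get(self, key: str) -> Any | None:
+        """Retrieve a value from the cache."""
+        return self._store.get(key)
+
+    async def set(self, key: str, value: Any, debug_data: dict | None = None) -> None:
+        """Write a value into the cache."""
+        self._store[key] = value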
diff --git a/func-app/graphrag/llm/types/llm_callbacks.py b/func-app/graphrag/llm/types/llm_callbacks.py new file mode 100644 index 0000000000..dc06dbff06 --- /dev/null +++ b/func-app/graphrag/llm/types/llm_callbacks.py @@ -0,0 +1,20 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Typing definitions for the OpenAI DataShaper package.""" + +from collections.abc import Callable + +from .llm_invocation_result import LLMInvocationResult + +ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None] +"""Error handler function type definition.""" + +LLMInvocationFn = Callable[[LLMInvocationResult], None] +"""Handler for LLM invocation results""" + +OnCacheActionFn = Callable[[str, str | None], None] +"""Handler for cache hits""" + +IsResponseValidFn = Callable[[dict], bool] +"""A function that checks if an LLM response is valid.""" diff --git a/func-app/graphrag/llm/types/llm_config.py b/func-app/graphrag/llm/types/llm_config.py new file mode 100644 index 0000000000..cd7ec255b2 --- /dev/null +++ b/func-app/graphrag/llm/types/llm_config.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Configuration Protocol definition.""" + +from typing import Protocol + + +class LLMConfig(Protocol): + """LLM Configuration Protocol definition.""" + + @property + def max_retries(self) -> int | None: + """Get the maximum number of retries.""" + ... + + @property + def max_retry_wait(self) -> float | None: + """Get the maximum retry wait time.""" + ... + + @property + def sleep_on_rate_limit_recommendation(self) -> bool | None: + """Get whether to sleep on rate limit recommendation.""" + ... + + @property + def tokens_per_minute(self) -> int | None: + """Get the number of tokens per minute.""" + ... + + @property + def requests_per_minute(self) -> int | None: + """Get the number of requests per minute.""" + ... diff --git a/func-app/graphrag/llm/types/llm_invocation_result.py b/func-app/graphrag/llm/types/llm_invocation_result.py new file mode 100644 index 0000000000..1769aeb96d --- /dev/null +++ b/func-app/graphrag/llm/types/llm_invocation_result.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Typing definitions for the OpenAI DataShaper package.""" + +from dataclasses import dataclass +from typing import Generic, TypeVar + +T = TypeVar("T") + + +@dataclass +class LLMInvocationResult(Generic[T]): + """The result of an LLM invocation.""" + + result: T | None + """The result of the LLM invocation.""" + + name: str + """The operation name of the result""" + + num_retries: int + """The number of retries the invocation took.""" + + total_time: float + """The total time of the LLM invocation.""" + + call_times: list[float] + """The network times of individual invocations.""" + + input_tokens: int + """The number of input tokens.""" + + output_tokens: int + """The number of output tokens.""" diff --git a/func-app/graphrag/llm/types/llm_io.py b/func-app/graphrag/llm/types/llm_io.py new file mode 100644 index 0000000000..256f3c8ce8 --- /dev/null +++ b/func-app/graphrag/llm/types/llm_io.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""LLM Types.""" + +from dataclasses import dataclass, field +from typing import Generic, TypeVar + +from typing_extensions import NotRequired, TypedDict + +from .llm_callbacks import IsResponseValidFn + + +class LLMInput(TypedDict): + """The input of an LLM invocation.""" + + name: NotRequired[str] + """The name of the LLM invocation, if available.""" + + json: NotRequired[bool] + """If true, will attempt to elicit JSON from the LLM. Parsed JSON will be returned in the `json_output` field.""" + + is_response_valid: NotRequired[IsResponseValidFn] + """A function that checks if an LLM response is valid. Only valid if `json=True`.""" + + variables: NotRequired[dict] + """The variable replacements to use in the prompt.""" + + history: NotRequired[list[dict] | None] + """The history of the LLM invocation, if available (e.g. chat mode)""" + + model_parameters: NotRequired[dict] + """Additional model parameters to use in the LLM invocation.""" + + +T = TypeVar("T") + + +@dataclass +class LLMOutput(Generic[T]): + """The output of an LLM invocation.""" + + output: T | None + """The output of the LLM invocation.""" + + json: dict | None = field(default=None) + """The JSON output from the LLM, if available.""" + + history: list[dict] | None = field(default=None) + """The history of the LLM invocation, if available (e.g. chat mode)""" diff --git a/func-app/graphrag/llm/types/llm_types.py b/func-app/graphrag/llm/types/llm_types.py new file mode 100644 index 0000000000..7ae76ef9be --- /dev/null +++ b/func-app/graphrag/llm/types/llm_types.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LLM Types.""" + +from typing import TypeAlias + +from .llm import LLM + +EmbeddingInput: TypeAlias = list[str] +EmbeddingOutput: TypeAlias = list[list[float]] +CompletionInput: TypeAlias = str +CompletionOutput: TypeAlias = str + +EmbeddingLLM: TypeAlias = LLM[EmbeddingInput, EmbeddingOutput] +CompletionLLM: TypeAlias = LLM[CompletionInput, CompletionOutput] diff --git a/func-app/graphrag/model/__init__.py b/func-app/graphrag/model/__init__.py new file mode 100644 index 0000000000..9dbec3d1dd --- /dev/null +++ b/func-app/graphrag/model/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +""" +GraphRAG knowledge model package root. + +The GraphRAG knowledge model contains a set of classes that represent the target datamodels for our pipelines and analytics tools. +These models can be augmented and integrated into your own data infrastructure to suit your needs. +""" + +from .community import Community +from .community_report import CommunityReport +from .covariate import Covariate +from .document import Document +from .entity import Entity +from .identified import Identified +from .named import Named +from .relationship import Relationship +from .text_unit import TextUnit + +__all__ = [ + "Community", + "CommunityReport", + "Covariate", + "Document", + "Entity", + "Identified", + "Named", + "Relationship", + "TextUnit", +] diff --git a/func-app/graphrag/model/community.py b/func-app/graphrag/model/community.py new file mode 100644 index 0000000000..800a9a292a --- /dev/null +++ b/func-app/graphrag/model/community.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A package containing the 'Community' model.""" + +from dataclasses import dataclass +from typing import Any + +from .named import Named + + +@dataclass +class Community(Named): + """A protocol for a community in the system.""" + + level: str = "" + """Community level.""" + + entity_ids: list[str] | None = None + """List of entity IDs related to the community (optional).""" + + relationship_ids: list[str] | None = None + """List of relationship IDs related to the community (optional).""" + + covariate_ids: dict[str, list[str]] | None = None + """Dictionary of different types of covariates related to the community (optional), e.g. claims""" + + attributes: dict[str, Any] | None = None + """A dictionary of additional attributes associated with the community (optional). To be included in the search prompt.""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + title_key: str = "title", + short_id_key: str = "short_id", + level_key: str = "level", + entities_key: str = "entity_ids", + relationships_key: str = "relationship_ids", + covariates_key: str = "covariate_ids", + attributes_key: str = "attributes", + ) -> "Community": + """Create a new community from the dict data.""" + return Community( + id=d[id_key], + title=d[title_key], + short_id=d.get(short_id_key), + level=d[level_key], + entity_ids=d.get(entities_key), + relationship_ids=d.get(relationships_key), + covariate_ids=d.get(covariates_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/community_report.py b/func-app/graphrag/model/community_report.py new file mode 100644 index 0000000000..2666c0b5a8 --- /dev/null +++ b/func-app/graphrag/model/community_report.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'CommunityReport' model.""" + +from dataclasses import dataclass +from typing import Any + +from .named import Named + + +@dataclass +class CommunityReport(Named): + """Defines an LLM-generated summary report of a community.""" + + community_id: str + """The ID of the community this report is associated with.""" + + summary: str = "" + """Summary of the report.""" + + full_content: str = "" + """Full content of the report.""" + + rank: float | None = 1.0 + """Rank of the report, used for sorting (optional). Higher means more important""" + + summary_embedding: list[float] | None = None + """The semantic (i.e. text) embedding of the report summary (optional).""" + + full_content_embedding: list[float] | None = None + """The semantic (i.e. 
text) embedding of the full report content (optional).""" + + attributes: dict[str, Any] | None = None + """A dictionary of additional attributes associated with the report (optional).""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + title_key: str = "title", + community_id_key: str = "community_id", + short_id_key: str = "short_id", + summary_key: str = "summary", + full_content_key: str = "full_content", + rank_key: str = "rank", + summary_embedding_key: str = "summary_embedding", + full_content_embedding_key: str = "full_content_embedding", + attributes_key: str = "attributes", + ) -> "CommunityReport": + """Create a new community report from the dict data.""" + return CommunityReport( + id=d[id_key], + title=d[title_key], + community_id=d[community_id_key], + short_id=d.get(short_id_key), + summary=d[summary_key], + full_content=d[full_content_key], + rank=d[rank_key], + summary_embedding=d.get(summary_embedding_key), + full_content_embedding=d.get(full_content_embedding_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/covariate.py b/func-app/graphrag/model/covariate.py new file mode 100644 index 0000000000..b974b6b327 --- /dev/null +++ b/func-app/graphrag/model/covariate.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'Covariate' model.""" + +from dataclasses import dataclass +from typing import Any + +from .identified import Identified + + +@dataclass +class Covariate(Identified): + """ + A protocol for a covariate in the system. + + Covariates are metadata associated with a subject, e.g. entity claims. + Each subject (e.g. entity) may be associated with multiple types of covariates. + """ + + subject_id: str + """The subject id.""" + + subject_type: str = "entity" + """The subject type.""" + + covariate_type: str = "claim" + """The covariate type.""" + + text_unit_ids: list[str] | None = None + """List of text unit IDs in which the covariate info appears (optional).""" + + document_ids: list[str] | None = None + """List of document IDs in which the covariate info appears (optional).""" + + attributes: dict[str, Any] | None = None + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + subject_id_key: str = "subject_id", + subject_type_key: str = "subject_type", + covariate_type_key: str = "covariate_type", + short_id_key: str = "short_id", + text_unit_ids_key: str = "text_unit_ids", + document_ids_key: str = "document_ids", + attributes_key: str = "attributes", + ) -> "Covariate": + """Create a new covariate from the dict data.""" + return Covariate( + id=d[id_key], + short_id=d.get(short_id_key), + subject_id=d[subject_id_key], + subject_type=d.get(subject_type_key, "entity"), + covariate_type=d.get(covariate_type_key, "claim"), + text_unit_ids=d.get(text_unit_ids_key), + document_ids=d.get(document_ids_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/document.py b/func-app/graphrag/model/document.py new file mode 100644 index 0000000000..b54a39ac91 --- /dev/null +++ b/func-app/graphrag/model/document.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A package containing the 'Document' model.""" + +from dataclasses import dataclass, field +from typing import Any + +from .named import Named + + +@dataclass +class Document(Named): + """A protocol for a document in the system.""" + + type: str = "text" + """Type of the document.""" + + text_unit_ids: list[str] = field(default_factory=list) + """list of text units in the document.""" + + raw_content: str = "" + """The raw text content of the document.""" + + summary: str | None = None + """Summary of the document (optional).""" + + summary_embedding: list[float] | None = None + """The semantic embedding for the document summary (optional).""" + + raw_content_embedding: list[float] | None = None + """The semantic embedding for the document raw content (optional).""" + + attributes: dict[str, Any] | None = None + """A dictionary of structured attributes such as author, etc (optional).""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + short_id_key: str = "short_id", + title_key: str = "title", + type_key: str = "type", + raw_content_key: str = "raw_content", + summary_key: str = "summary", + summary_embedding_key: str = "summary_embedding", + raw_content_embedding_key: str = "raw_content_embedding", + text_units_key: str = "text_units", + attributes_key: str = "attributes", + ) -> "Document": + """Create a new document from the dict data.""" + return Document( + id=d[id_key], + short_id=d.get(short_id_key), + title=d[title_key], + type=d.get(type_key, "text"), + raw_content=d[raw_content_key], + summary=d.get(summary_key), + summary_embedding=d.get(summary_embedding_key), + raw_content_embedding=d.get(raw_content_embedding_key), + text_unit_ids=d.get(text_units_key, []), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/entity.py b/func-app/graphrag/model/entity.py new file mode 100644 index 0000000000..37c26342aa --- /dev/null +++ b/func-app/graphrag/model/entity.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'Entity' model.""" + +from dataclasses import dataclass +from typing import Any + +from .named import Named + + +@dataclass +class Entity(Named): + """A protocol for an entity in the system.""" + + type: str | None = None + """Type of the entity (can be any string, optional).""" + + description: str | None = None + """Description of the entity (optional).""" + + description_embedding: list[float] | None = None + """The semantic (i.e. text) embedding of the entity (optional).""" + + name_embedding: list[float] | None = None + """The semantic (i.e. text) embedding of the entity (optional).""" + + graph_embedding: list[float] | None = None + """The graph embedding of the entity, likely from node2vec (optional).""" + + community_ids: list[str] | None = None + """The community IDs of the entity (optional).""" + + text_unit_ids: list[str] | None = None + """List of text unit IDs in which the entity appears (optional).""" + + document_ids: list[str] | None = None + """List of document IDs in which the entity appears (optional).""" + + rank: int | None = 1 + """Rank of the entity, used for sorting (optional). Higher rank indicates more important entity. This can be based on centrality or other metrics.""" + + attributes: dict[str, Any] | None = None + """Additional attributes associated with the entity (optional), e.g. start time, end time, etc. 
To be included in the search prompt.""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + short_id_key: str = "short_id", + title_key: str = "title", + type_key: str = "type", + description_key: str = "description", + description_embedding_key: str = "description_embedding", + name_embedding_key: str = "name_embedding", + graph_embedding_key: str = "graph_embedding", + community_key: str = "community", + text_unit_ids_key: str = "text_unit_ids", + document_ids_key: str = "document_ids", + rank_key: str = "degree", + attributes_key: str = "attributes", + ) -> "Entity": + """Create a new entity from the dict data.""" + return Entity( + id=d[id_key], + title=d[title_key], + short_id=d.get(short_id_key), + type=d.get(type_key), + description=d.get(description_key), + name_embedding=d.get(name_embedding_key), + description_embedding=d.get(description_embedding_key), + graph_embedding=d.get(graph_embedding_key), + community_ids=d.get(community_key), + rank=d.get(rank_key, 1), + text_unit_ids=d.get(text_unit_ids_key), + document_ids=d.get(document_ids_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/identified.py b/func-app/graphrag/model/identified.py new file mode 100644 index 0000000000..ca2c939526 --- /dev/null +++ b/func-app/graphrag/model/identified.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'Identified' protocol.""" + +from dataclasses import dataclass + + +@dataclass +class Identified: + """A protocol for an item with an ID.""" + + id: str + """The ID of the item.""" + + short_id: str | None + """Human readable ID used to refer to this community in prompts or texts displayed to users, such as in a report text (optional).""" diff --git a/func-app/graphrag/model/named.py b/func-app/graphrag/model/named.py new file mode 100644 index 0000000000..5352c77c96 --- /dev/null +++ b/func-app/graphrag/model/named.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'Named' protocol.""" + +from dataclasses import dataclass + +from .identified import Identified + + +@dataclass +class Named(Identified): + """A protocol for an item with a name/title.""" + + title: str + """The name/title of the item.""" diff --git a/func-app/graphrag/model/relationship.py b/func-app/graphrag/model/relationship.py new file mode 100644 index 0000000000..fadd0aaa6f --- /dev/null +++ b/func-app/graphrag/model/relationship.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'Relationship' model.""" + +from dataclasses import dataclass +from typing import Any + +from .identified import Identified + + +@dataclass +class Relationship(Identified): + """A relationship between two entities. 
This is a generic relationship, and can be used to represent any type of relationship between any two entities.""" + + source: str + """The source entity name.""" + + target: str + """The target entity name.""" + + weight: float | None = 1.0 + """The edge weight.""" + + description: str | None = None + """A description of the relationship (optional).""" + + description_embedding: list[float] | None = None + """The semantic embedding for the relationship description (optional).""" + + text_unit_ids: list[str] | None = None + """List of text unit IDs in which the relationship appears (optional).""" + + document_ids: list[str] | None = None + """List of document IDs in which the relationship appears (optional).""" + + attributes: dict[str, Any] | None = None + """Additional attributes associated with the relationship (optional). To be included in the search prompt""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + short_id_key: str = "short_id", + source_key: str = "source", + target_key: str = "target", + description_key: str = "description", + weight_key: str = "weight", + text_unit_ids_key: str = "text_unit_ids", + document_ids_key: str = "document_ids", + attributes_key: str = "attributes", + ) -> "Relationship": + """Create a new relationship from the dict data.""" + return Relationship( + id=d[id_key], + short_id=d.get(short_id_key), + source=d[source_key], + target=d[target_key], + description=d.get(description_key), + weight=d.get(weight_key, 1.0), + text_unit_ids=d.get(text_unit_ids_key), + document_ids=d.get(document_ids_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/text_unit.py b/func-app/graphrag/model/text_unit.py new file mode 100644 index 0000000000..cff4ac01c1 --- /dev/null +++ b/func-app/graphrag/model/text_unit.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the 'TextUnit' model.""" + +from dataclasses import dataclass +from typing import Any + +from .identified import Identified + + +@dataclass +class TextUnit(Identified): + """A protocol for a TextUnit item in a Document database.""" + + text: str + """The text of the unit.""" + + text_embedding: list[float] | None = None + """The text embedding for the text unit (optional).""" + + entity_ids: list[str] | None = None + """List of entity IDs related to the text unit (optional).""" + + relationship_ids: list[str] | None = None + """List of relationship IDs related to the text unit (optional).""" + + covariate_ids: dict[str, list[str]] | None = None + "Dictionary of different types of covariates related to the text unit (optional)." 
+ + n_tokens: int | None = None + """The number of tokens in the text (optional).""" + + document_ids: list[str] | None = None + """List of document IDs in which the text unit appears (optional).""" + + attributes: dict[str, Any] | None = None + """A dictionary of additional attributes associated with the text unit (optional).""" + + @classmethod + def from_dict( + cls, + d: dict[str, Any], + id_key: str = "id", + short_id_key: str = "short_id", + text_key: str = "text", + text_embedding_key: str = "text_embedding", + entities_key: str = "entity_ids", + relationships_key: str = "relationship_ids", + covariates_key: str = "covariate_ids", + n_tokens_key: str = "n_tokens", + document_ids_key: str = "document_ids", + attributes_key: str = "attributes", + ) -> "TextUnit": + """Create a new text unit from the dict data.""" + return TextUnit( + id=d[id_key], + short_id=d.get(short_id_key), + text=d[text_key], + text_embedding=d.get(text_embedding_key), + entity_ids=d.get(entities_key), + relationship_ids=d.get(relationships_key), + covariate_ids=d.get(covariates_key), + n_tokens=d.get(n_tokens_key), + document_ids=d.get(document_ids_key), + attributes=d.get(attributes_key), + ) diff --git a/func-app/graphrag/model/types.py b/func-app/graphrag/model/types.py new file mode 100644 index 0000000000..6156e39969 --- /dev/null +++ b/func-app/graphrag/model/types.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Common types for the GraphRAG knowledge model.""" + +from collections.abc import Callable + +TextEmbedder = Callable[[str], list[float]] diff --git a/func-app/graphrag/prompt_tune/__init__.py b/func-app/graphrag/prompt_tune/__init__.py new file mode 100644 index 0000000000..2384b5793c --- /dev/null +++ b/func-app/graphrag/prompt_tune/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Command line interface for the fine_tune module.""" diff --git a/func-app/graphrag/prompt_tune/__main__.py b/func-app/graphrag/prompt_tune/__main__.py new file mode 100644 index 0000000000..e752b05a8f --- /dev/null +++ b/func-app/graphrag/prompt_tune/__main__.py @@ -0,0 +1,148 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The Prompt auto templating package root.""" + +import argparse +import asyncio +from enum import Enum + +from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT +from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE + +from .cli import prompt_tune + + +class DocSelectionType(Enum): + """The type of document selection to use.""" + + ALL = "all" + RANDOM = "random" + TOP = "top" + AUTO = "auto" + + def __str__(self): + """Return the string representation of the enum value.""" + return self.value + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--root", + help="The data project root. Including the config yml, json or .env", + required=False, + type=str, + default=".", + ) + + parser.add_argument( + "--domain", + help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. 
If left empty, the domain will be inferred from the input data.",
+        required=False,
+        default="",
+        type=str,
+    )
+
+    parser.add_argument(
+        "--method",
+        help="The method to select documents, one of: all, random, top or auto",
+        required=False,
+        type=DocSelectionType,
+        choices=list(DocSelectionType),
+        default=DocSelectionType.RANDOM,
+    )
+
+    parser.add_argument(
+        "--n_subset_max",
+        help="The number of text chunks to embed when using auto selection method",
+        required=False,
+        type=int,
+        default=300,
+    )
+
+    parser.add_argument(
+        "--k",
+        help="The maximum number of documents to select from each centroid when using auto selection method",
+        required=False,
+        type=int,
+        default=15,
+    )
+
+    parser.add_argument(
+        "--limit",
+        help="The limit of files to load when doing random or top selection",
+        type=int,
+        required=False,
+        default=15,
+    )
+
+    parser.add_argument(
+        "--max-tokens",
+        help="Max token count for prompt generation",
+        type=int,
+        required=False,
+        default=MAX_TOKEN_COUNT,
+    )
+
+    parser.add_argument(
+        "--min-examples-required",
+        help="The minimum number of examples required in entity extraction prompt",
+        type=int,
+        required=False,
+        default=2,
+    )
+
+    parser.add_argument(
+        "--chunk-size",
+        help="Max token count per text chunk when splitting input documents",
+        type=int,
+        required=False,
+        default=MIN_CHUNK_SIZE,
+    )
+
+    parser.add_argument(
+        "--language",
+        help="Primary language used for inputs and outputs on GraphRAG",
+        type=str,
+        required=False,
+        default="",
+    )
+
+    parser.add_argument(
+        "--no-entity-types",
+        help="Use untyped entity extraction generation",
+        action="store_true",
+        required=False,
+        default=False,
+    )
+
+    parser.add_argument(
+        "--output",
+        help="Folder to save the generated prompts to",
+        type=str,
+        required=False,
+        default="prompts",
+    )
+
+    args = parser.parse_args()
+
+    loop = asyncio.get_event_loop()
+
+    loop.run_until_complete(
+        prompt_tune(
+            args.root,
+            args.domain,
+            str(args.method),
+            args.limit,
+            args.max_tokens,
+            args.chunk_size,
+            args.language,
+            args.no_entity_types,
+            args.output,
+            args.n_subset_max,
+            args.k,
+            args.min_examples_required,
+        )
+    )
diff --git a/func-app/graphrag/prompt_tune/cli.py b/func-app/graphrag/prompt_tune/cli.py
new file mode 100644
index 0000000000..5979a4a6ee
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/cli.py
@@ -0,0 +1,272 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Command line interface for the fine_tune module.""" + +from pathlib import Path + +from datashaper import NoopVerbCallbacks + +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.llm import load_llm +from graphrag.index.progress import PrintProgressReporter +from graphrag.index.progress.types import ProgressReporter +from graphrag.llm.types.llm_types import CompletionLLM +from graphrag.prompt_tune.generator import ( + MAX_TOKEN_COUNT, + create_community_summarization_prompt, + create_entity_extraction_prompt, + create_entity_summarization_prompt, + detect_language, + generate_community_report_rating, + generate_community_reporter_role, + generate_domain, + generate_entity_relationship_examples, + generate_entity_types, + generate_persona, +) +from graphrag.prompt_tune.loader import ( + MIN_CHUNK_SIZE, + load_docs_in_chunks, + read_config_parameters, +) + + +async def prompt_tune( + root: str, + domain: str, + select: str = "random", + limit: int = 15, + max_tokens: int = MAX_TOKEN_COUNT, + chunk_size: int = MIN_CHUNK_SIZE, + language: str | None = None, + skip_entity_types: bool = False, + output: str = "prompts", + n_subset_max: int = 300, + k: int = 15, + min_examples_required: int = 2, +): + """Prompt tune the model. + + Parameters + ---------- + - root: The root directory. + - domain: The domain to map the input documents to. + - select: The chunk selection method. + - limit: The limit of chunks to load. + - max_tokens: The maximum number of tokens to use on entity extraction prompts. + - chunk_size: The chunk token size to use. + - skip_entity_types: Skip generating entity types. + - output: The output folder to store the prompts. + - n_subset_max: The number of text chunks to embed when using auto selection method. + - k: The number of documents to select when using auto selection method. + """ + reporter = PrintProgressReporter("") + config = read_config_parameters(root, reporter) + + await prompt_tune_with_config( + root, + config, + domain, + select, + limit, + max_tokens, + chunk_size, + language, + skip_entity_types, + output, + reporter, + n_subset_max, + k, + min_examples_required, + ) + + +async def prompt_tune_with_config( + root: str, + config: GraphRagConfig, + domain: str, + select: str = "random", + limit: int = 15, + max_tokens: int = MAX_TOKEN_COUNT, + chunk_size: int = MIN_CHUNK_SIZE, + language: str | None = None, + skip_entity_types: bool = False, + output: str = "prompts", + reporter: ProgressReporter | None = None, + n_subset_max: int = 300, + k: int = 15, + min_examples_required: int = 2, +): + """Prompt tune the model with a configuration. + + Parameters + ---------- + - root: The root directory. + - config: The GraphRag configuration. + - domain: The domain to map the input documents to. + - select: The chunk selection method. + - limit: The limit of chunks to load. + - max_tokens: The maximum number of tokens to use on entity extraction prompts. + - chunk_size: The chunk token size to use for input text units. + - skip_entity_types: Skip generating entity types. + - output: The output folder to store the prompts. + - reporter: The progress reporter. + - n_subset_max: The number of text chunks to embed when using auto selection method. + - k: The number of documents to select when using auto selection method. 
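+    - language: The primary language used for inputs and outputs, detected from the input docs if not provided.
+    - min_examples_required: The minimum number of examples required in the entity extraction prompt.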
+
+    Returns
+    -------
+    - None
+    """
+    if not reporter:
+        reporter = PrintProgressReporter("")
+
+    output_path = Path(config.root_dir) / output
+
+    doc_list = await load_docs_in_chunks(
+        root=root,
+        config=config,
+        limit=limit,
+        select_method=select,
+        reporter=reporter,
+        chunk_size=chunk_size,
+        n_subset_max=n_subset_max,
+        k=k,
+    )
+
+    # Create LLM from config
+    llm = load_llm(
+        "prompt_tuning",
+        config.llm.type,
+        NoopVerbCallbacks(),
+        None,
+        config.llm.model_dump(),
+    )
+
+    await generate_indexing_prompts(
+        llm,
+        config,
+        doc_list,
+        output_path,
+        reporter,
+        domain,
+        language,
+        max_tokens,
+        skip_entity_types,
+        min_examples_required,
+    )
+
+
+async def generate_indexing_prompts(
+    llm: CompletionLLM,
+    config: GraphRagConfig,
+    doc_list: list[str],
+    output_path: Path,
+    reporter: ProgressReporter,
+    domain: str | None = None,
+    language: str | None = None,
+    max_tokens: int = MAX_TOKEN_COUNT,
+    skip_entity_types: bool = False,
+    min_examples_required: int = 2,
+):
+    """Generate indexing prompts.
+
+    Parameters
+    ----------
+    - llm: The LLM model to use.
+    - config: The GraphRag configuration.
+    - doc_list: The list of documents to use.
+    - output_path: The path to store the prompts.
+    - reporter: The progress reporter.
+    - domain: The domain to map the input documents to.
+    - language: The primary language used for inputs and outputs.
+    - max_tokens: The maximum number of tokens to use on entity extraction prompts.
+    - skip_entity_types: Skip generating entity types.
+    - min_examples_required: The minimum number of examples required for entity extraction prompts.
+    """
+    if not domain:
+        reporter.info("Generating domain...")
+        domain = await generate_domain(llm, doc_list)
+        reporter.info(f"Generated domain: {domain}")
+
+    if not language:
+        reporter.info("Detecting language...")
+        language = await detect_language(llm, doc_list)
+        reporter.info(f"Detected language: {language}")
+
+    reporter.info("Generating persona...")
+    persona = await generate_persona(llm, domain)
+    reporter.info(f"Generated persona: {persona}")
+
+    reporter.info("Generating community report ranking description...")
+    community_report_ranking = await generate_community_report_rating(
+        llm, domain=domain, persona=persona, docs=doc_list
+    )
+    reporter.info(
+        f"Generated community report ranking description: {community_report_ranking}"
+    )
+
+    entity_types = None
+    if not skip_entity_types:
+        reporter.info("Generating entity types...")
+        entity_types = await generate_entity_types(
+            llm,
+            domain=domain,
+            persona=persona,
+            docs=doc_list,
+            json_mode=config.llm.model_supports_json or False,
+        )
+        reporter.info(f"Generated entity types: {entity_types}")
+
+    reporter.info("Generating entity relationship examples...")
+    examples = await generate_entity_relationship_examples(
+        llm,
+        persona=persona,
+        entity_types=entity_types,
+        docs=doc_list,
+        language=language,
+        json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
+    )
+    reporter.info("Done generating entity relationship examples")
+
+    reporter.info("Generating entity extraction prompt...")
+    create_entity_extraction_prompt(
+        entity_types=entity_types,
+        docs=doc_list,
+        examples=examples,
+        language=language,
+        json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
+        output_path=output_path,
+        encoding_model=config.encoding_model,
+        max_token_count=max_tokens,
+        min_examples_required=min_examples_required,
+    )
+    reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
+
+    reporter.info("Generating entity summarization prompt...")
+    create_entity_summarization_prompt(
+        persona=persona,
+        language=language,
+        output_path=output_path,
+    )
+    reporter.info(
+        f"Generated entity summarization prompt, stored in folder {output_path}"
+    )
+
+    reporter.info("Generating community reporter role...")
+    community_reporter_role = await generate_community_reporter_role(
+        llm, domain=domain, persona=persona, docs=doc_list
+    )
+    reporter.info(f"Generated community reporter role: {community_reporter_role}")
+
+    reporter.info("Generating community summarization prompt...")
+    create_community_summarization_prompt(
+        persona=persona,
+        role=community_reporter_role,
+        report_rating_description=community_report_ranking,
+        language=language,
+        output_path=output_path,
+    )
+    reporter.info(
+        f"Generated community summarization prompt, stored in folder {output_path}"
+    )
diff --git a/func-app/graphrag/prompt_tune/generator/__init__.py b/func-app/graphrag/prompt_tune/generator/__init__.py
new file mode 100644
index 0000000000..df45b46033
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/__init__.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Prompt generation module."""
+
+from .community_report_rating import generate_community_report_rating
+from .community_report_summarization import create_community_summarization_prompt
+from .community_reporter_role import generate_community_reporter_role
+from .defaults import MAX_TOKEN_COUNT
+from .domain import generate_domain
+from .entity_extraction_prompt import create_entity_extraction_prompt
+from .entity_relationship import generate_entity_relationship_examples
+from .entity_summarization_prompt import create_entity_summarization_prompt
+from .entity_types import generate_entity_types
+from .language import detect_language
+from .persona import generate_persona
+
+__all__ = [
+    "MAX_TOKEN_COUNT",
+    "create_community_summarization_prompt",
+    "create_entity_extraction_prompt",
+    "create_entity_summarization_prompt",
+    "detect_language",
+    "generate_community_report_rating",
+    "generate_community_reporter_role",
+    "generate_domain",
+    "generate_entity_relationship_examples",
+    "generate_entity_types",
+    "generate_persona",
+]
diff --git a/func-app/graphrag/prompt_tune/generator/community_report_rating.py b/func-app/graphrag/prompt_tune/generator/community_report_rating.py
new file mode 100644
index 0000000000..59f94d5698
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/community_report_rating.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Generate a rating description for community report rating."""
+
+from graphrag.llm.types.llm_types import CompletionLLM
+from graphrag.prompt_tune.prompt import (
+    GENERATE_REPORT_RATING_PROMPT,
+)
+
+
+async def generate_community_report_rating(
+    llm: CompletionLLM, domain: str, persona: str, docs: str | list[str]
+) -> str:
+    """Generate a rating description for community report ranking.
+
+    Parameters
+    ----------
+    - llm (CompletionLLM): The LLM to use for generation
+    - domain (str): The domain to generate a rating for
+    - persona (str): The persona to generate a rating for
+    - docs (str | list[str]): Documents used to contextualize the rating
+
+    Returns
+    -------
+    - str: The generated rating description prompt response.
+    """
+    docs_str = " ".join(docs) if isinstance(docs, list) else docs
+    domain_prompt = GENERATE_REPORT_RATING_PROMPT.format(
+        domain=domain, persona=persona, input_text=docs_str
+    )
+
+    response = await llm(domain_prompt)
+
+    return str(response.output).strip()
diff --git a/func-app/graphrag/prompt_tune/generator/community_report_summarization.py b/func-app/graphrag/prompt_tune/generator/community_report_summarization.py
new file mode 100644
index 0000000000..b0c0b614d2
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/community_report_summarization.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Module for generating prompts for community report summarization."""
+
+from pathlib import Path
+
+from graphrag.prompt_tune.template import COMMUNITY_REPORT_SUMMARIZATION_PROMPT
+
+COMMUNITY_SUMMARIZATION_FILENAME = "community_report.txt"
+
+
+def create_community_summarization_prompt(
+    persona: str,
+    role: str,
+    report_rating_description: str,
+    language: str,
+    output_path: Path | None = None,
+) -> str:
+    """Create a prompt for community summarization. If output_path is provided, write the prompt to a file.
+
+    Parameters
+    ----------
+    - persona (str): The persona to use for the community summarization prompt
+    - role (str): The role to use for the community summarization prompt
+    - report_rating_description (str): The rating description to include in the prompt
+    - language (str): The language to use for the community summarization prompt
+    - output_path (Path | None): The path to write the prompt to. If None, the prompt is not written to a file. Default is None.
+
+    Returns
+    -------
+    - str: The community summarization prompt
+    """
+    prompt = COMMUNITY_REPORT_SUMMARIZATION_PROMPT.format(
+        persona=persona,
+        role=role,
+        report_rating_description=report_rating_description,
+        language=language,
+    )
+
+    if output_path:
+        output_path.mkdir(parents=True, exist_ok=True)
+
+        output_path = output_path / COMMUNITY_SUMMARIZATION_FILENAME
+        # Write file to output path
+        with output_path.open("wb") as file:
+            file.write(prompt.encode(encoding="utf-8", errors="strict"))
+
+    return prompt
diff --git a/func-app/graphrag/prompt_tune/generator/community_reporter_role.py b/func-app/graphrag/prompt_tune/generator/community_reporter_role.py
new file mode 100644
index 0000000000..9abd5ed83f
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/community_reporter_role.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Generate a community reporter role for community summarization."""
+
+from graphrag.llm.types.llm_types import CompletionLLM
+from graphrag.prompt_tune.prompt import (
+    GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT,
+)
+
+
+async def generate_community_reporter_role(
+    llm: CompletionLLM, domain: str, persona: str, docs: str | list[str]
+) -> str:
+    """Generate a community reporter role to use for GraphRAG prompts.
+
+    Parameters
+    ----------
+    - llm (CompletionLLM): The LLM to use for generation
+    - domain (str): The domain to generate a role for
+    - persona (str): The persona to generate a role for
+    - docs (str | list[str]): Documents used to contextualize the role
+
+    Returns
+    -------
+    - str: The generated community reporter role prompt response.
+    """
+    docs_str = " ".join(docs) if isinstance(docs, list) else docs
+    domain_prompt = GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT.format(
+        domain=domain, persona=persona, input_text=docs_str
+    )
+
+    response = await llm(domain_prompt)
+
+    return str(response.output)
diff --git a/func-app/graphrag/prompt_tune/generator/defaults.py b/func-app/graphrag/prompt_tune/generator/defaults.py
new file mode 100644
index 0000000000..5b42f81332
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/defaults.py
@@ -0,0 +1,10 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Default values for the fine-tuning module."""
+
+DEFAULT_TASK = """
+Identify the relations and structure of the community of interest, specifically within the {domain} domain.
+"""
+
+MAX_TOKEN_COUNT = 2000
diff --git a/func-app/graphrag/prompt_tune/generator/domain.py b/func-app/graphrag/prompt_tune/generator/domain.py
new file mode 100644
index 0000000000..49c698d1b4
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/domain.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Domain generation for GraphRAG prompts."""
+
+from graphrag.llm.types.llm_types import CompletionLLM
+from graphrag.prompt_tune.prompt.domain import GENERATE_DOMAIN_PROMPT
+
+
+async def generate_domain(llm: CompletionLLM, docs: str | list[str]) -> str:
+    """Generate a domain to use for GraphRAG prompts.
+
+    Parameters
+    ----------
+    - llm (CompletionLLM): The LLM to use for generation
+    - docs (str | list[str]): The docs to infer the domain from
+
+    Returns
+    -------
+    - str: The generated domain prompt response.
+    """
+    docs_str = " ".join(docs) if isinstance(docs, list) else docs
+    domain_prompt = GENERATE_DOMAIN_PROMPT.format(input_text=docs_str)
+
+    response = await llm(domain_prompt)
+
+    return str(response.output)
diff --git a/func-app/graphrag/prompt_tune/generator/entity_extraction_prompt.py b/func-app/graphrag/prompt_tune/generator/entity_extraction_prompt.py
new file mode 100644
index 0000000000..3b17dbab5d
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/entity_extraction_prompt.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Entity Extraction prompt generator module."""
+
+from pathlib import Path
+
+import graphrag.config.defaults as defs
+from graphrag.index.utils.tokens import num_tokens_from_string
+from graphrag.prompt_tune.template import (
+    EXAMPLE_EXTRACTION_TEMPLATE,
+    GRAPH_EXTRACTION_JSON_PROMPT,
+    GRAPH_EXTRACTION_PROMPT,
+    UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE,
+    UNTYPED_GRAPH_EXTRACTION_PROMPT,
+)
+
+ENTITY_EXTRACTION_FILENAME = "entity_extraction.txt"
+
+
+def create_entity_extraction_prompt(
+    entity_types: str | list[str] | None,
+    docs: list[str],
+    examples: list[str],
+    language: str,
+    max_token_count: int,
+    encoding_model: str = defs.ENCODING_MODEL,
+    json_mode: bool = False,
+    output_path: Path | None = None,
+    min_examples_required: int = 2,
+) -> str:
+    """
+    Create a prompt for entity extraction.
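+    If output_path is provided, the prompt is also written to entity_extraction.txt in that folder.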
+
+    Parameters
+    ----------
+    - entity_types (str | list[str]): The entity types to extract
+    - docs (list[str]): The list of documents to extract entities from
+    - examples (list[str]): The list of examples to use for entity extraction
+    - language (str): The language of the inputs and outputs
+    - encoding_model (str): The name of the model to use for token counting
+    - max_token_count (int): The maximum number of tokens to use for the prompt
+    - json_mode (bool): Whether to use JSON mode for the prompt. Default is False
+    - output_path (Path | None): The path to write the prompt to. If None, the prompt is not written to a file. Default is None.
+    - min_examples_required (int): The minimum number of examples required. Default is 2.
+
+    Returns
+    -------
+    - str: The entity extraction prompt
+    """
+    prompt = (
+        (GRAPH_EXTRACTION_JSON_PROMPT if json_mode else GRAPH_EXTRACTION_PROMPT)
+        if entity_types
+        else UNTYPED_GRAPH_EXTRACTION_PROMPT
+    )
+    if isinstance(entity_types, list):
+        entity_types = ", ".join(entity_types)
+
+    tokens_left = (
+        max_token_count
+        - num_tokens_from_string(prompt, model=encoding_model)
+        - num_tokens_from_string(entity_types, model=encoding_model)
+        if entity_types
+        else 0
+    )
+
+    examples_prompt = ""
+
+    # Iterate over the examples while token budget remains
+    for i, output in enumerate(examples):
+        input = docs[i]
+        example_formatted = (
+            EXAMPLE_EXTRACTION_TEMPLATE.format(
+                n=i + 1, input_text=input, entity_types=entity_types, output=output
+            )
+            if entity_types
+            else UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE.format(
+                n=i + 1, input_text=input, output=output
+            )
+        )
+
+        example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)
+
+        # Ensure at least min_examples_required examples are included before enforcing the token budget
+        if i >= min_examples_required and example_tokens > tokens_left:
+            break
+
+        examples_prompt += example_formatted
+        tokens_left -= example_tokens
+
+    prompt = (
+        prompt.format(
+            entity_types=entity_types, examples=examples_prompt, language=language
+        )
+        if entity_types
+        else prompt.format(examples=examples_prompt, language=language)
+    )
+
+    if output_path:
+        output_path.mkdir(parents=True, exist_ok=True)
+
+        output_path = output_path / ENTITY_EXTRACTION_FILENAME
+        # Write file to output path
+        with output_path.open("wb") as file:
+            file.write(prompt.encode(encoding="utf-8", errors="strict"))
+
+    return prompt
diff --git a/func-app/graphrag/prompt_tune/generator/entity_relationship.py b/func-app/graphrag/prompt_tune/generator/entity_relationship.py
new file mode 100644
index 0000000000..72ecb5f4da
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/entity_relationship.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Entity relationship example generation module."""
+
+import asyncio
+import json
+
+from graphrag.llm.types.llm_types import CompletionLLM
+from graphrag.prompt_tune.prompt import (
+    ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT,
+    ENTITY_RELATIONSHIPS_GENERATION_PROMPT,
+    UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT,
+)
+
+MAX_EXAMPLES = 5
+
+
+async def generate_entity_relationship_examples(
+    llm: CompletionLLM,
+    persona: str,
+    entity_types: str | list[str] | None,
+    docs: str | list[str],
+    language: str,
+    json_mode: bool = False,
+) -> list[str]:
+    """Generate a list of entity/relationship examples for use in generating an entity configuration.
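+    At most MAX_EXAMPLES (5) documents are used, producing one generated example per document.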
+
+    Will return entity/relationship examples as either JSON or in tuple_delimiter format,
+    depending on the json_mode parameter.
+    """
+    docs_list = [docs] if isinstance(docs, str) else docs
+    history = [{"role": "system", "content": persona}]
+
+    if entity_types:
+        entity_types_str = (
+            entity_types if isinstance(entity_types, str) else ", ".join(entity_types)
+        )
+
+        messages = [
+            (
+                ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT
+                if json_mode
+                else ENTITY_RELATIONSHIPS_GENERATION_PROMPT
+            ).format(entity_types=entity_types_str, input_text=doc, language=language)
+            for doc in docs_list
+        ]
+    else:
+        messages = [
+            UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT.format(
+                input_text=doc, language=language
+            )
+            for doc in docs_list
+        ]
+
+    messages = messages[:MAX_EXAMPLES]
+
+    tasks = [llm(message, history=history, json=json_mode) for message in messages]
+
+    responses = await asyncio.gather(*tasks)
+
+    return [
+        json.dumps(response.json or "") if json_mode else str(response.output)
+        for response in responses
+    ]
diff --git a/func-app/graphrag/prompt_tune/generator/entity_summarization_prompt.py b/func-app/graphrag/prompt_tune/generator/entity_summarization_prompt.py
new file mode 100644
index 0000000000..4ae5af77ec
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/entity_summarization_prompt.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Entity summarization prompt generation module."""
+
+from pathlib import Path
+
+from graphrag.prompt_tune.template import ENTITY_SUMMARIZATION_PROMPT
+
+ENTITY_SUMMARIZATION_FILENAME = "summarize_descriptions.txt"
+
+
+def create_entity_summarization_prompt(
+    persona: str,
+    language: str,
+    output_path: Path | None = None,
+) -> str:
+    """Create a prompt for entity summarization. If output_path is provided, write the prompt to a file.
+
+    Parameters
+    ----------
+    - persona (str): The persona to use for the entity summarization prompt
+    - language (str): The language to use for the entity summarization prompt
+    - output_path (Path | None): The path to write the prompt to. If None, the prompt is not written to a file. Default is None.
+
+    Returns
+    -------
+    - str: The entity summarization prompt
+    """
+    prompt = ENTITY_SUMMARIZATION_PROMPT.format(persona=persona, language=language)
+
+    if output_path:
+        output_path.mkdir(parents=True, exist_ok=True)
+
+        output_path = output_path / ENTITY_SUMMARIZATION_FILENAME
+        # Write file to output path
+        with output_path.open("wb") as file:
+            file.write(prompt.encode(encoding="utf-8", errors="strict"))
+
+    return prompt
diff --git a/func-app/graphrag/prompt_tune/generator/entity_types.py b/func-app/graphrag/prompt_tune/generator/entity_types.py
new file mode 100644
index 0000000000..42518acd8c
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/entity_types.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Entity type generation module for fine-tuning."""
+
+from graphrag.llm.types.llm_types import CompletionLLM
+from graphrag.prompt_tune.generator.defaults import DEFAULT_TASK
+from graphrag.prompt_tune.prompt.entity_types import (
+    ENTITY_TYPE_GENERATION_JSON_PROMPT,
+    ENTITY_TYPE_GENERATION_PROMPT,
+)
+
+
+async def generate_entity_types(
+    llm: CompletionLLM,
+    domain: str,
+    persona: str,
+    docs: str | list[str],
+    task: str = DEFAULT_TASK,
+    json_mode: bool = False,
+) -> str | list[str]:
+    """
+    Generate entity type categories from a given set of (small) documents.
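+    When json_mode is True, the parsed "entity_types" list from the JSON response
+    is returned; otherwise the raw LLM output string is returned.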
diff --git a/func-app/graphrag/prompt_tune/generator/entity_summarization_prompt.py b/func-app/graphrag/prompt_tune/generator/entity_summarization_prompt.py
new file mode 100644
index 0000000000..4ae5af77ec
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/entity_summarization_prompt.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Entity summarization prompt generation module."""
+
+from pathlib import Path
+
+from graphrag.prompt_tune.template import ENTITY_SUMMARIZATION_PROMPT
+
+ENTITY_SUMMARIZATION_FILENAME = "summarize_descriptions.txt"
+
+
+def create_entity_summarization_prompt(
+    persona: str,
+    language: str,
+    output_path: Path | None = None,
+) -> str:
+    """Create a prompt for entity summarization. If output_path is provided, write the prompt to a file.
+
+    Parameters
+    ----------
+    - persona (str): The persona to use for the entity summarization prompt
+    - language (str): The language to use for the entity summarization prompt
+    - output_path (Path | None): The path to write the prompt to. If None, the prompt is not written to a file. Default is None.
+    """
+    prompt = ENTITY_SUMMARIZATION_PROMPT.format(persona=persona, language=language)
+
+    if output_path:
+        output_path.mkdir(parents=True, exist_ok=True)
+
+        output_path = output_path / ENTITY_SUMMARIZATION_FILENAME
+        # Write file to output path
+        with output_path.open("wb") as file:
+            file.write(prompt.encode(encoding="utf-8", errors="strict"))
+
+    return prompt
diff --git a/func-app/graphrag/prompt_tune/generator/entity_types.py b/func-app/graphrag/prompt_tune/generator/entity_types.py
new file mode 100644
index 0000000000..42518acd8c
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/entity_types.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Entity type generation module for fine-tuning."""
+
+from graphrag.llm.types.llm_types import CompletionLLM
+from graphrag.prompt_tune.generator.defaults import DEFAULT_TASK
+from graphrag.prompt_tune.prompt.entity_types import (
+    ENTITY_TYPE_GENERATION_JSON_PROMPT,
+    ENTITY_TYPE_GENERATION_PROMPT,
+)
+
+
+async def generate_entity_types(
+    llm: CompletionLLM,
+    domain: str,
+    persona: str,
+    docs: str | list[str],
+    task: str = DEFAULT_TASK,
+    json_mode: bool = False,
+) -> str | list[str]:
+    """
+    Generate entity type categories from a given set of (small) documents.
+
+    Example Output:
+    "entity_types": ['military unit', 'organization', 'person', 'location', 'event', 'date', 'equipment']
+    """
+    formatted_task = task.format(domain=domain)
+
+    docs_str = "\n".join(docs) if isinstance(docs, list) else docs
+
+    entity_types_prompt = (
+        ENTITY_TYPE_GENERATION_JSON_PROMPT
+        if json_mode
+        else ENTITY_TYPE_GENERATION_PROMPT
+    ).format(task=formatted_task, input_text=docs_str)
+
+    history = [{"role": "system", "content": persona}]
+
+    response = await llm(entity_types_prompt, history=history, json=json_mode)
+
+    if json_mode:
+        return (response.json or {}).get("entity_types", [])
+
+    return str(response.output)
diff --git a/func-app/graphrag/prompt_tune/generator/language.py b/func-app/graphrag/prompt_tune/generator/language.py
new file mode 100644
index 0000000000..38de531ca3
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/language.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Language detection for GraphRAG prompts."""
+
+from graphrag.llm.types.llm_types import CompletionLLM
+from graphrag.prompt_tune.prompt import DETECT_LANGUAGE_PROMPT
+
+
+async def detect_language(llm: CompletionLLM, docs: str | list[str]) -> str:
+    """Detect input language to use for GraphRAG prompts.
+
+    Parameters
+    ----------
+    - llm (CompletionLLM): The LLM to use for generation
+    - docs (str | list[str]): The docs to detect language from
+
+    Returns
+    -------
+    - str: The detected language.
+    """
+    docs_str = " ".join(docs) if isinstance(docs, list) else docs
+    language_prompt = DETECT_LANGUAGE_PROMPT.format(input_text=docs_str)
+
+    response = await llm(language_prompt)
+
+    return str(response.output)
diff --git a/func-app/graphrag/prompt_tune/generator/persona.py b/func-app/graphrag/prompt_tune/generator/persona.py
new file mode 100644
index 0000000000..cdd57a655d
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/generator/persona.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Persona generating module for fine-tuning GraphRAG prompts."""
+
+from graphrag.llm.types.llm_types import CompletionLLM
+from graphrag.prompt_tune.generator.defaults import DEFAULT_TASK
+from graphrag.prompt_tune.prompt import GENERATE_PERSONA_PROMPT
+
+
+async def generate_persona(
+    llm: CompletionLLM, domain: str, task: str = DEFAULT_TASK
+) -> str:
+    """Generate an LLM persona to use for GraphRAG prompts.
+
+    Parameters
+    ----------
+    - llm (CompletionLLM): The LLM to use for generation
+    - domain (str): The domain to generate a persona for
+    - task (str): The task to generate a persona for. Default is DEFAULT_TASK
+    """
+    formatted_task = task.format(domain=domain)
+    persona_prompt = GENERATE_PERSONA_PROMPT.format(sample_task=formatted_task)
+
+    response = await llm(persona_prompt)
+
+    return str(response.output)
diff --git a/func-app/graphrag/prompt_tune/loader/__init__.py b/func-app/graphrag/prompt_tune/loader/__init__.py
new file mode 100644
index 0000000000..94e64cbe87
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/loader/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Fine-tuning config and data loader module.""" + +from .config import read_config_parameters +from .input import MIN_CHUNK_OVERLAP, MIN_CHUNK_SIZE, load_docs_in_chunks + +__all__ = [ + "MIN_CHUNK_OVERLAP", + "MIN_CHUNK_SIZE", + "load_docs_in_chunks", + "read_config_parameters", +] diff --git a/func-app/graphrag/prompt_tune/loader/config.py b/func-app/graphrag/prompt_tune/loader/config.py new file mode 100644 index 0000000000..8994604f92 --- /dev/null +++ b/func-app/graphrag/prompt_tune/loader/config.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Config loading, parsing and handling module.""" + +from pathlib import Path + +from graphrag.config import create_graphrag_config +from graphrag.index.progress.types import ProgressReporter + + +def read_config_parameters(root: str, reporter: ProgressReporter): + """Read the configuration parameters from the settings file or environment variables. + + Parameters + ---------- + - root: The root directory where the parameters are. + - reporter: The progress reporter. + """ + _root = Path(root) + settings_yaml = _root / "settings.yaml" + if not settings_yaml.exists(): + settings_yaml = _root / "settings.yml" + settings_json = _root / "settings.json" + + if settings_yaml.exists(): + reporter.info(f"Reading settings from {settings_yaml}") + with settings_yaml.open("rb") as file: + import yaml + + data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + if settings_json.exists(): + reporter.info(f"Reading settings from {settings_json}") + with settings_json.open("rb") as file: + import json + + data = json.loads(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + reporter.info("Reading settings from environment variables") + return create_graphrag_config(root_dir=root) diff --git a/func-app/graphrag/prompt_tune/loader/input.py b/func-app/graphrag/prompt_tune/loader/input.py new file mode 100644 index 0000000000..86c4a76040 --- /dev/null +++ b/func-app/graphrag/prompt_tune/loader/input.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Input loading module."""
+
+from typing import cast
+
+import numpy as np
+import pandas as pd
+from datashaper import NoopVerbCallbacks, TableContainer, VerbInput
+
+import graphrag.config.defaults as defs
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.input import load_input
+from graphrag.index.llm import load_llm_embeddings
+from graphrag.index.progress.types import ProgressReporter
+from graphrag.index.verbs import chunk
+from graphrag.llm.types.llm_types import EmbeddingLLM
+
+MIN_CHUNK_OVERLAP = 0
+MIN_CHUNK_SIZE = 200
+N_SUBSET_MAX = 300
+K = 15
+
+
+async def _embed_chunks(
+    text_chunks: pd.DataFrame,
+    embedding_llm: EmbeddingLLM,
+    n_subset_max: int = N_SUBSET_MAX,
+) -> tuple[pd.DataFrame, np.ndarray]:
+    """Convert text chunks into dense text embeddings."""
+    sampled_text_chunks = text_chunks.sample(n=min(n_subset_max, len(text_chunks)))
+    embeddings = await embedding_llm(sampled_text_chunks["chunks"].tolist())
+    return text_chunks, np.array(embeddings.output)
+
+
+def _sample_chunks_from_embeddings(
+    text_chunks: pd.DataFrame,
+    embeddings,
+    k: int = K,
+) -> pd.DataFrame:
+    """Sample the k text chunks whose embeddings lie closest to the embedding centroid."""
+    center = np.mean(embeddings, axis=0)
+    distances = np.linalg.norm(embeddings - center, axis=1)
+    nearest_indices = np.argsort(distances)[:k]
+
+    return text_chunks.iloc[nearest_indices]
+
+
+async def load_docs_in_chunks(
+    root: str,
+    config: GraphRagConfig,
+    select_method: str,
+    limit: int,
+    reporter: ProgressReporter,
+    chunk_size: int = MIN_CHUNK_SIZE,
+    n_subset_max: int = N_SUBSET_MAX,
+    k: int = K,
+) -> list[str]:
+    """Load docs into chunks for generating prompts."""
+    dataset = await load_input(config.input, reporter, root)
+
+    # convert to text units
+    input = VerbInput(input=TableContainer(table=dataset))
+    chunk_strategy = config.chunks.resolved_strategy(defs.ENCODING_MODEL)
+
+    # Use smaller chunks, to avoid huge prompts
+    chunk_strategy["chunk_size"] = chunk_size
+    chunk_strategy["chunk_overlap"] = MIN_CHUNK_OVERLAP
+
+    dataset_chunks_table_container = chunk(
+        input,
+        column="text",
+        to="chunks",
+        callbacks=NoopVerbCallbacks(),
+        strategy=chunk_strategy,
+    )
+
+    dataset_chunks = cast(pd.DataFrame, dataset_chunks_table_container.table)
+
+    # Select chunks into a new df and explode it
+    chunks_df = pd.DataFrame(dataset_chunks["chunks"].explode())  # type: ignore
+
+    # Depending on the select method, build the dataset
+    if limit <= 0 or limit > len(chunks_df):
+        limit = len(chunks_df)
+
+    if select_method == "top":
+        chunks_df = chunks_df[:limit]
+    elif select_method == "random":
+        chunks_df = chunks_df.sample(n=limit)
+    elif select_method == "auto":
+        if k is None or k <= 0:
+            msg = "k must be an integer > 0"
+            raise ValueError(msg)
+        embedding_llm = load_llm_embeddings(
+            name="prompt_tuning_embeddings",
+            llm_type=config.embeddings.resolved_strategy()["llm"]["type"],
+            callbacks=NoopVerbCallbacks(),
+            cache=None,
+            llm_config=config.embeddings.resolved_strategy()["llm"],
+        )
+
+        chunks_df, embeddings = await _embed_chunks(
+            chunks_df, embedding_llm, n_subset_max=n_subset_max
+        )
+        chunks_df = _sample_chunks_from_embeddings(chunks_df, embeddings, k=k)
+
+    # Convert the dataset to list form, so we have a list of documents
+    return chunks_df["chunks"].tolist()
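For `select_method="auto"`, the interesting step is `_sample_chunks_from_embeddings`: after embedding a sample of up to `N_SUBSET_MAX` chunks, it keeps the `k` chunks closest to the embedding centroid, on the idea that central chunks are the most representative of the corpus. A self-contained sketch of that selection step, with invented 2-D embeddings standing in for real LLM embeddings:

```python
import numpy as np
import pandas as pd

# Invented toy data: five chunks with fake 2-D embeddings.
chunks_df = pd.DataFrame({"chunks": ["a", "b", "c", "d", "e"]})
embeddings = np.array([[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [0.2, 0.0], [4.0, 4.5]])

k = 3
center = np.mean(embeddings, axis=0)                     # centroid of all embeddings
distances = np.linalg.norm(embeddings - center, axis=1)  # L2 distance to the centroid
nearest_indices = np.argsort(distances)[:k]              # indices of the k closest

print(chunks_df.iloc[nearest_indices]["chunks"].tolist())  # the most "typical" chunks
```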
diff --git a/func-app/graphrag/prompt_tune/prompt/__init__.py b/func-app/graphrag/prompt_tune/prompt/__init__.py
new file mode 100644
index 0000000000..991d52856e
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/prompt/__init__.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Persona, entity type, relationships and domain generation prompts module."""
+
+from .community_report_rating import GENERATE_REPORT_RATING_PROMPT
+from .community_reporter_role import GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT
+from .domain import GENERATE_DOMAIN_PROMPT
+from .entity_relationship import (
+    ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT,
+    ENTITY_RELATIONSHIPS_GENERATION_PROMPT,
+    UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT,
+)
+from .entity_types import (
+    ENTITY_TYPE_GENERATION_JSON_PROMPT,
+    ENTITY_TYPE_GENERATION_PROMPT,
+)
+from .language import DETECT_LANGUAGE_PROMPT
+from .persona import GENERATE_PERSONA_PROMPT
+
+__all__ = [
+    "DETECT_LANGUAGE_PROMPT",
+    "ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT",
+    "ENTITY_RELATIONSHIPS_GENERATION_PROMPT",
+    "ENTITY_TYPE_GENERATION_JSON_PROMPT",
+    "ENTITY_TYPE_GENERATION_PROMPT",
+    "GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT",
+    "GENERATE_DOMAIN_PROMPT",
+    "GENERATE_PERSONA_PROMPT",
+    "GENERATE_REPORT_RATING_PROMPT",
+    "UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT",
+]
diff --git a/func-app/graphrag/prompt_tune/prompt/community_report_rating.py b/func-app/graphrag/prompt_tune/prompt/community_report_rating.py
new file mode 100644
index 0000000000..b061645b94
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/prompt/community_report_rating.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Fine-tuning prompts for Community Reports Rating."""
+
+GENERATE_REPORT_RATING_PROMPT = """
+
+You are a helpful agent tasked with rating the importance of a given text in the context of the provided domain and persona. Your goal is to provide a rating that reflects the relevance and significance of the text to the specified domain and persona. Use your expertise to evaluate the text based on the importance criteria and assign a float score between 0-10. Only respond with the text description of the importance criteria. Use the provided example data format to guide your response. Ignore the content of the example data and focus on the structure.
+
+######################
+-Examples-
+######################
+
+### Example 1
+
+# Domain
+
+Personal and Family Communication
+
+# Persona
+
+You are an expert in Social Network Analysis with a focus on the Personal and Family Communication domain. You are skilled at mapping and interpreting complex social networks, understanding the dynamics of interpersonal relationships, and identifying patterns of communication within communities. You are adept at helping people understand the structure and relations within their personal and family networks, providing insights into how information flows, how strong various connections are, and how these networks influence individual and group behavior.
+
+# Data
+
+
+Subject: Re: Event
+From: Alice Brown alice.brown@example.com
+Date: 2012-11-14, 9:52 a.m.
+To: John Smith john.smith@example.com
+CC: Jane Doe jane.doe@example.com, Bob Johnson bob.johnson@example.com, Emma Davis emma.davis@example.com
+
+The event is at 6pm at City Hall (Queen street) event chamber. We
+just need to get there by 5:45pm. It is 30-minute long so we will be
+done by 6:30pm. We'll then head over to New Sky on Spadina for some
+unique cuisine!
+
+Guests are you and Emma, and my uncle and auntie from London
+who my folks have designated to act as their reps.
Jane and Joe are +witnesses. + +Be there or be square! +Alice + +On Wed, Nov 14, 2012 at 9:40 AM, John Smith john.smith@example.com wrote: + +Thats the day after Bob's event! +Any more details on the event schedule? ITS NEXT WEEK! +On Tue, Nov 13, 2012 at 7:51 PM, Jane Doe +jane.doe@example.com wrote: +I am supposed to forward you the invitation to this year's celebration. +Date: Saturday, Nov. 24, 6 pm starting +Place as usual: Dean's house, 6 Cardish, Kleinburg L0J 1C0 +Jane Doe +jane.doe@example.com + +# Importance Criteria + +A float score between 0-10 that represents the relevance of the email's content to family communication, health concerns, travel plans, and interpersonal dynamics, with 1 being trivial or spam and 10 being highly relevant, urgent, and impactful to family cohesion or well-being. +############################# + +### Example 2 + +# Domain + +Literary Analysis + +# Persona + +You are a literary scholar with a focus on works from the 19th century. You are skilled at analyzing and interpreting texts, identifying themes and motifs, and understanding the historical and cultural contexts in which these works were written. You are adept at helping people understand the deeper meanings and significance of literary works, providing insights into the author's intentions, the social issues addressed in the text, and the impact of these works on contemporary society. + +# Data + +Had she found Jane in any apparent danger, Mrs. Bennet would have been very miserable; but being satisfied on seeing her that her illness was not alarming, she had no wish of her recovering immediately, as her restoration to health would probably remove her from Netherfield. She would not listen, therefore, to her daughter's proposal of being carried home; neither did the apothecary, who arrived about the same time, think it at all advisable. After sitting a little with Jane, on Miss Bingley's appearance and invitation, the mother and three daughters all attended her into the breakfast parlor. Bingley met them with hopes that Mrs. Bennet had not found Miss Bennet worse than she expected. + +"Indeed I have, Sir," was her answer. "She is a great deal too ill to be moved. Mr. Jones says we must not think of moving her. We must trespass a little longer on your kindness." + +"Removed!" cried Bingley. "It must not be thought of. My sister, I am sure, will not hear of her removal." + +# Importance Criteria + +A float score between 0-10 that represents the relevance of the text to literary analysis, historical context, thematic interpretation, and cultural significance, with 1 being trivial or irrelevant and 10 being highly significant, profound, and impactful to the understanding of the text and its implications. +############################# + +### Example 3 + +# Domain + +Environmental Science + +# Persona + +You are an environmental scientist with a focus on climate change and sustainability. You are skilled at analyzing data, interpreting social commentary and recommending policy changes. You are adept at helping people understand the causes and consequences of climate change, providing insights into how they can reduce their carbon footprint, adopt sustainable practices, and contribute to a healthier planet. + +# Data + +Host 1 (Anna): Welcome to "Green Living Today," the podcast where we explore practical tips and inspiring stories about sustainable living. I'm your host, Anna Green. + +Host 2 (Mark): And I'm Mark Smith. 
Today, we have a special episode focused on reducing plastic waste in our daily lives. We'll be talking to a special guest who has made significant strides in living a plastic-free lifestyle. + +Anna: That's right, Mark. Our guest today is Laura Thompson, the founder of "Plastic-Free Living," a blog dedicated to sharing tips and resources for reducing plastic use. Welcome to the show, Laura! + +Guest (Laura): Thanks, Anna and Mark. It's great to be here. + +Mark: Laura, let's start by talking about your journey. What inspired you to start living a plastic-free lifestyle? + +# Importance Criteria + +A float score between 0-10 that represents the relevance of the text to sustainability, plastic waste reduction, and environmental policies, with 1 being trivial or irrelevant and 10 being highly significant, impactful, and actionable in promoting environmental awareness. +############################# + + +############################# +-Real Data- +############################# + +# Domain + +{domain} + +# Persona + +{persona} + +# Data + +{input_text} + +# Importance Criteria + + +""" diff --git a/func-app/graphrag/prompt_tune/prompt/community_reporter_role.py b/func-app/graphrag/prompt_tune/prompt/community_reporter_role.py new file mode 100644 index 0000000000..b667bc2940 --- /dev/null +++ b/func-app/graphrag/prompt_tune/prompt/community_reporter_role.py @@ -0,0 +1,20 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Fine-tuning prompts for community reporter role generation.""" + +GENERATE_COMMUNITY_REPORTER_ROLE_PROMPT = """ +{persona} +Given a sample text, help the user by creating a role definition that will be tasked with community analysis. +Take a look at this example, determine its key parts, and using the domain provided and your expertise, create a new role definition for the provided inputs that follows the same pattern as the example. +Remember, your output should look just like the provided example in structure and content. + +Example: +A technologist reporter that is analyzing Kevin Scott's "Behind the Tech Podcast", given a list of entities +that belong to the community as well as their relationships and optional associated claims. +The report will be used to inform decision-makers about significant developments associated with the community and their potential impact. + + +Domain: {domain} +Text: {input_text} +Role:""" diff --git a/func-app/graphrag/prompt_tune/prompt/domain.py b/func-app/graphrag/prompt_tune/prompt/domain.py new file mode 100644 index 0000000000..4b4587f8d8 --- /dev/null +++ b/func-app/graphrag/prompt_tune/prompt/domain.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Fine-tuning prompts for domain generation.""" + +GENERATE_DOMAIN_PROMPT = """ +You are an intelligent assistant that helps a human to analyze the information in a text document. +Given a sample text, help the user by assigning a descriptive domain that summarizes what the text is about. +Example domains are: "Social studies", "Algorithmic analysis", "Medical science", among others. + +Text: {input_text} +Domain:""" diff --git a/func-app/graphrag/prompt_tune/prompt/entity_relationship.py b/func-app/graphrag/prompt_tune/prompt/entity_relationship.py new file mode 100644 index 0000000000..3af77db641 --- /dev/null +++ b/func-app/graphrag/prompt_tune/prompt/entity_relationship.py @@ -0,0 +1,355 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Fine-tuning prompts for entity relationship generation."""
+
+ENTITY_RELATIONSHIPS_GENERATION_PROMPT = """
+-Goal-
+Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
+
+-Steps-
+1. Identify all entities. For each identified entity, extract the following information:
+- entity_name: Name of the entity, capitalized
+- entity_type: One of the following types: [{entity_types}]
+- entity_description: Comprehensive description of the entity's attributes and activities
+Format each entity as ("entity"{{tuple_delimiter}}<entity_name>{{tuple_delimiter}}<entity_type>{{tuple_delimiter}}<entity_description>)
+
+2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: name of the source entity, as identified in step 1
+- target_entity: name of the target entity, as identified in step 1
+- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
+- relationship_strength: an integer score between 1 and 10, indicating strength of the relationship between the source entity and target entity
+Format each relationship as ("relationship"{{tuple_delimiter}}<source_entity>{{tuple_delimiter}}<target_entity>{{tuple_delimiter}}<relationship_description>{{tuple_delimiter}}<relationship_strength>)
+
+3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{{record_delimiter}}** as the list delimiter.
+
+4. If you have to translate into {language}, just translate the descriptions, nothing else!
+
+5. When finished, output {{completion_delimiter}}.
+
+######################
+-Examples-
+######################
+Example 1:
+Entity_types: ORGANIZATION,PERSON
+Text:
+The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%.
+######################
+Output:
+("entity"{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday)
+{{record_delimiter}}
+("entity"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Martin Smith is the chair of the Central Institution)
+{{record_delimiter}}
+("entity"{{tuple_delimiter}}MARKET STRATEGY COMMITTEE{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply)
+{{record_delimiter}}
+("relationship"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}Martin Smith is the Chair of the Central Institution and will answer questions at a press conference{{tuple_delimiter}}9)
+{{completion_delimiter}}
+
+######################
+Example 2:
+Entity_types: ORGANIZATION
+Text:
+TechGlobal's (TG) stock skyrocketed in its opening day on the Global Exchange Thursday.
But IPO experts warn that the semiconductor corporation's debut on the public markets isn't indicative of how other newly listed companies may perform. + +TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones. +###################### +Output: +("entity"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) +{{record_delimiter}} +("entity"{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}Vision Holdings is a firm that previously owned TechGlobal) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}Vision Holdings formerly owned TechGlobal from 2014 until present{{tuple_delimiter}}5) +{{completion_delimiter}} + +###################### +Example 3: +Entity_types: ORGANIZATION,GEO,PERSON +Text: +Five Aurelians jailed for 8 years in Firuzabad and widely regarded as hostages are on their way home to Aurelia. + +The swap orchestrated by Quintara was finalized when $8bn of Firuzi funds were transferred to financial institutions in Krohaara, the capital of Quintara. + +The exchange initiated in Firuzabad's capital, Tiruzia, led to the four men and one woman, who are also Firuzi nationals, boarding a chartered flight to Krohaara. + +They were welcomed by senior Aurelian officials and are now on their way to Aurelia's capital, Cashion. + +The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality. 
+###################### +Output: +("entity"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}GEO{{tuple_delimiter}}Firuzabad held Aurelians as hostages) +{{record_delimiter}} +("entity"{{tuple_delimiter}}AURELIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country seeking to release hostages) +{{record_delimiter}} +("entity"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country that negotiated a swap of money in exchange for hostages) +{{record_delimiter}} +{{record_delimiter}} +("entity"{{tuple_delimiter}}TIRUZIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital of Firuzabad where the Aurelians were being held) +{{record_delimiter}} +("entity"{{tuple_delimiter}}KROHAARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Quintara) +{{record_delimiter}} +("entity"{{tuple_delimiter}}CASHION{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Aurelia) +{{record_delimiter}} +("entity"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian who spent time in Tiruzia's Alhamia Prison) +{{record_delimiter}} +("entity"{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}GEO{{tuple_delimiter}}Prison in Tiruzia) +{{record_delimiter}} +("entity"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian journalist who was held hostage) +{{record_delimiter}} +("entity"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Bratinas national and environmentalist who was held hostage) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Firuzabad negotiated a hostage exchange with Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}Samuel Namara was a prisoner at Alhamia prison{{tuple_delimiter}}8) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}Samuel Namara and Meggie Tazbah were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Samuel Namara and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Samuel Namara was a hostage in Firuzabad{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Meggie Tazbah was a hostage in Firuzabad{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Durke Bataglani was a hostage in Firuzabad{{tuple_delimiter}}2) +{{completion_delimiter}} + +-Real Data- +###################### +entity_types: 
{entity_types}
+text: {input_text}
+######################
+output:
+"""
+
+ENTITY_RELATIONSHIPS_GENERATION_JSON_PROMPT = """
+-Goal-
+Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
+
+-Steps-
+1. Identify all entities. For each identified entity, extract the following information:
+- entity_name: Name of the entity, capitalized
+- entity_type: One of the following types: [{entity_types}]
+- entity_description: Comprehensive description of the entity's attributes and activities
+
+Format each entity output as a JSON entry with the following format:
+
+{{"name": <entity_name>, "type": <entity_type>, "description": <entity_description>}}
+
+2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: name of the source entity, as identified in step 1
+- target_entity: name of the target entity, as identified in step 1
+- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
+- relationship_strength: an integer score between 1 and 10, indicating strength of the relationship between the source entity and target entity
+
+Format each relationship as a JSON entry with the following format:
+
+{{"source": <source_entity>, "target": <target_entity>, "relationship": <relationship_description>, "relationship_strength": <relationship_strength>}}
+
+3. Return output in {language} as a single list of all JSON entities and relationships identified in steps 1 and 2.
+
+4. If you have to translate into {language}, just translate the descriptions, nothing else!
+
+######################
+-Examples-
+######################
+Example 1:
+Text:
+The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%.
+######################
+Output:
+[
+    {{"name": "CENTRAL INSTITUTION", "type": "ORGANIZATION", "description": "The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday"}},
+    {{"name": "MARTIN SMITH", "type": "PERSON", "description": "Martin Smith is the chair of the Central Institution"}},
+    {{"name": "MARKET STRATEGY COMMITTEE", "type": "ORGANIZATION", "description": "The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply"}},
+    {{"source": "MARTIN SMITH", "target": "CENTRAL INSTITUTION", "relationship": "Martin Smith is the Chair of the Central Institution and will answer questions at a press conference", "relationship_strength": 9}}
+]
+
+######################
+Example 2:
+Text:
+TechGlobal's (TG) stock skyrocketed in its opening day on the Global Exchange Thursday. But IPO experts warn that the semiconductor corporation's debut on the public markets isn't indicative of how other newly listed companies may perform.
+
+TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones.
+###################### +Output: +[ + {{"name": "TECHGLOBAL", "type": "ORGANIZATION", "description": "TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones"}}, + {{"name": "VISION HOLDINGS", "type": "ORGANIZATION", "description": "Vision Holdings is a firm that previously owned TechGlobal"}}, + {{"source": "TECHGLOBAL", "target": "VISION HOLDINGS", "relationship": "Vision Holdings formerly owned TechGlobal from 2014 until present", "relationship_strength": 5}} +] + +###################### +Example 3: +Text: +Five Aurelians jailed for 8 years in Firuzabad and widely regarded as hostages are on their way home to Aurelia. + +The swap orchestrated by Quintara was finalized when $8bn of Firuzi funds were transferred to financial institutions in Krohaara, the capital of Quintara. + +The exchange initiated in Firuzabad's capital, Tiruzia, led to the four men and one woman, who are also Firuzi nationals, boarding a chartered flight to Krohaara. + +They were welcomed by senior Aurelian officials and are now on their way to Aurelia's capital, Cashion. + +The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality. +###################### +Output: +[ + {{"name": "FIRUZABAD", "type": "GEO", "description": "Firuzabad held Aurelians as hostages"}}, + {{"name": "AURELIA", "type": "GEO", "description": "Country seeking to release hostages"}}, + {{"name": "QUINTARA", "type": "GEO", "description": "Country that negotiated a swap of money in exchange for hostages"}}, + {{"name": "TIRUZIA", "type": "GEO", "description": "Capital of Firuzabad where the Aurelians were being held"}}, + {{"name": "KROHAARA", "type": "GEO", "description": "Capital city in Quintara"}}, + {{"name": "CASHION", "type": "GEO", "description": "Capital city in Aurelia"}}, + {{"name": "SAMUEL NAMARA", "type": "PERSON", "description": "Aurelian who spent time in Tiruzia's Alhamia Prison"}}, + {{"name": "ALHAMIA PRISON", "type": "GEO", "description": "Prison in Tiruzia"}}, + {{"name": "DURKE BATAGLANI", "type": "PERSON", "description": "Aurelian journalist who was held hostage"}}, + {{"name": "MEGGIE TAZBAH", "type": "PERSON", "description": "Bratinas national and environmentalist who was held hostage"}}, + {{"source": "FIRUZABAD", "target": "AURELIA", "relationship": "Firuzabad negotiated a hostage exchange with Aurelia", "relationship_strength": 2}}, + {{"source": "QUINTARA", "target": "AURELIA", "relationship": "Quintara brokered the hostage exchange between Firuzabad and Aurelia", "relationship_strength": 2}}, + {{"source": "QUINTARA", "target": "FIRUZABAD", "relationship": "Quintara brokered the hostage exchange between Firuzabad and Aurelia", "relationship_strength": 2}}, + {{"source": "SAMUEL NAMARA", "target": "ALHAMIA PRISON", "relationship": "Samuel Namara was a prisoner at Alhamia prison", "relationship_strength": 8}}, + {{"source": "SAMUEL NAMARA", "target": "MEGGIE TAZBAH", "relationship": "Samuel Namara and Meggie Tazbah were exchanged in the same hostage release", "relationship_strength": 2}}, + {{"source": "SAMUEL NAMARA", "target": "DURKE BATAGLANI", "relationship": "Samuel Namara and Durke Bataglani were exchanged in the same hostage release", "relationship_strength": 2}}, + {{"source": "MEGGIE TAZBAH", "target": "DURKE BATAGLANI", "relationship": "Meggie Tazbah and Durke Bataglani were exchanged in the same hostage 
release", "relationship_strength": 2}},
+    {{"source": "SAMUEL NAMARA", "target": "FIRUZABAD", "relationship": "Samuel Namara was a hostage in Firuzabad", "relationship_strength": 2}},
+    {{"source": "MEGGIE TAZBAH", "target": "FIRUZABAD", "relationship": "Meggie Tazbah was a hostage in Firuzabad", "relationship_strength": 2}},
+    {{"source": "DURKE BATAGLANI", "target": "FIRUZABAD", "relationship": "Durke Bataglani was a hostage in Firuzabad", "relationship_strength": 2}}
+]
+
+
+-Real Data-
+######################
+entity_types: {entity_types}
+text: {input_text}
+######################
+output:
+"""
+
+UNTYPED_ENTITY_RELATIONSHIPS_GENERATION_PROMPT = """
+-Goal-
+Given a text document that is potentially relevant to this activity, first identify all entities needed from the text in order to capture the information and ideas in the text.
+Next, report all relationships among the identified entities.
+
+-Steps-
+1. Identify all entities. For each identified entity, extract the following information:
+- entity_name: Name of the entity, capitalized
+- entity_type: Suggest several labels or categories for the entity. The categories should not be specific, but should be as general as possible.
+- entity_description: Comprehensive description of the entity's attributes and activities
+Format each entity as ("entity"{{tuple_delimiter}}<entity_name>{{tuple_delimiter}}<entity_type>{{tuple_delimiter}}<entity_description>)
+
+2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: name of the source entity, as identified in step 1
+- target_entity: name of the target entity, as identified in step 1
+- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
+- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
+Format each relationship as ("relationship"{{tuple_delimiter}}<source_entity>{{tuple_delimiter}}<target_entity>{{tuple_delimiter}}<relationship_description>{{tuple_delimiter}}<relationship_strength>)
+
+3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{{record_delimiter}}** as the list delimiter.
+
+4. If you have to translate into {language}, just translate the descriptions, nothing else!
+
+5. When finished, output {{completion_delimiter}}.
+
+######################
+-Examples-
+######################
+Example 1:
+Text:
+The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%.
+###################### +Output: +("entity"{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday) +{{record_delimiter}} +("entity"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Martin Smith is the chair of the Central Institution) +{{record_delimiter}} +("entity"{{tuple_delimiter}}MARKET STRATEGY COMMITTEE{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}Martin Smith is the Chair of the Central Institution and will answer questions at a press conference{{tuple_delimiter}}9) +{{completion_delimiter}} + +###################### +Example 2: +Text: +TechGlobal's (TG) stock skyrocketed in its opening day on the Global Exchange Thursday. But IPO experts warn that the semiconductor corporation's debut on the public markets isn't indicative of how other newly listed companies may perform. + +TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones. +###################### +Output: +("entity"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) +{{record_delimiter}} +("entity"{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}Vision Holdings is a firm that previously owned TechGlobal) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}Vision Holdings formerly owned TechGlobal from 2014 until present{{tuple_delimiter}}5) +{{completion_delimiter}} + +###################### +Example 3: +Text: +Five Aurelians jailed for 8 years in Firuzabad and widely regarded as hostages are on their way home to Aurelia. + +The swap orchestrated by Quintara was finalized when $8bn of Firuzi funds were transferred to financial institutions in Krohaara, the capital of Quintara. + +The exchange initiated in Firuzabad's capital, Tiruzia, led to the four men and one woman, who are also Firuzi nationals, boarding a chartered flight to Krohaara. + +They were welcomed by senior Aurelian officials and are now on their way to Aurelia's capital, Cashion. + +The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality. 
+###################### +Output: +("entity"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}GEO{{tuple_delimiter}}Firuzabad held Aurelians as hostages) +{{record_delimiter}} +("entity"{{tuple_delimiter}}AURELIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country seeking to release hostages) +{{record_delimiter}} +("entity"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country that negotiated a swap of money in exchange for hostages) +{{record_delimiter}} +{{record_delimiter}} +("entity"{{tuple_delimiter}}TIRUZIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital of Firuzabad where the Aurelians were being held) +{{record_delimiter}} +("entity"{{tuple_delimiter}}KROHAARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Quintara) +{{record_delimiter}} +("entity"{{tuple_delimiter}}CASHION{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Aurelia) +{{record_delimiter}} +("entity"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian who spent time in Tiruzia's Alhamia Prison) +{{record_delimiter}} +("entity"{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}GEO{{tuple_delimiter}}Prison in Tiruzia) +{{record_delimiter}} +("entity"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian journalist who was held hostage) +{{record_delimiter}} +("entity"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Bratinas national and environmentalist who was held hostage) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Firuzabad negotiated a hostage exchange with Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}Samuel Namara was a prisoner at Alhamia prison{{tuple_delimiter}}8) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}Samuel Namara and Meggie Tazbah were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Samuel Namara and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Samuel Namara was a hostage in Firuzabad{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Meggie Tazbah was a hostage in Firuzabad{{tuple_delimiter}}2) +{{record_delimiter}} +("relationship"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Durke Bataglani was a hostage in Firuzabad{{tuple_delimiter}}2) +{{completion_delimiter}} + +###################### +-Real Data- +###################### 
+Text: {input_text}
+######################
+Output:
+"""
diff --git a/func-app/graphrag/prompt_tune/prompt/entity_types.py b/func-app/graphrag/prompt_tune/prompt/entity_types.py
new file mode 100644
index 0000000000..99b21db645
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/prompt/entity_types.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Fine-tuning prompts for entity types generation."""
+
+ENTITY_TYPE_GENERATION_PROMPT = """
+The goal is to study the connections and relations between the entity types and their features in order to understand all available information from the text.
+The user's task is to {task}.
+As part of the analysis, you want to identify the entity types present in the following text.
+The entity types must be relevant to the user task.
+Avoid general entity types such as "other" or "unknown".
+This is VERY IMPORTANT: Do not generate redundant or overlapping entity types. For example, if the text contains "company" and "organization" entity types, you should return only one of them.
+Don't worry about quantity, always choose quality over quantity. And make sure EVERYTHING in your answer is relevant to the context of entity extraction.
+And remember, it is ENTITY TYPES that we need.
+Return the entity types as a list of comma-separated strings.
+=====================================================================
+EXAMPLE SECTION: The following section includes example output. These examples **must be excluded from your answer**.
+
+EXAMPLE 1
+Task: Determine the connections and organizational hierarchy within the specified community.
+Text: Example_Org_A is a company in Sweden. Example_Org_A's director is Example_Individual_B.
+RESPONSE:
+organization, person
+END OF EXAMPLE 1
+
+EXAMPLE 2
+Task: Identify the key concepts, principles, and arguments shared among different philosophical schools of thought, and trace the historical or ideological influences they have on each other.
+Text: Rationalism, epitomized by thinkers such as René Descartes, holds that reason is the primary source of knowledge. Key concepts within this school include the emphasis on the deductive method of reasoning.
+RESPONSE:
+concept, person, school of thought
+END OF EXAMPLE 2
+
+EXAMPLE 3
+Task: Identify the full range of basic forces, factors, and trends that would indirectly shape an issue.
+Text: Industry leaders such as Panasonic are vying for supremacy in the battery production sector. They are investing heavily in research and development and are exploring new technologies to gain a competitive edge.
+RESPONSE:
+organization, technology, sectors, investment strategies
+END OF EXAMPLE 3
+======================================================================
+
+======================================================================
+REAL DATA: The following section is the real data. You should use only this real data to prepare your answer. Generate Entity Types only.
+Task: {task}
+Text: {input_text}
+RESPONSE:
+{{<entity_types>}}
+"""
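The two constants differ only in response shape: the plain prompt asks for a comma-separated line, while the JSON variant asks for an object keyed by "entity_types", which is what `generate_entity_types` unpacks via `response.json`. A small, hypothetical sketch of how each raw response would be post-processed (the parsing below is illustrative, not code from this patch):

```python
# Plain-text mode: the model answers with a comma-separated line.
raw_text = "organization, person, school of thought"
entity_types = [t.strip() for t in raw_text.split(",")]

# JSON mode: the model answers with an object keyed by "entity_types".
raw_json = {"entity_types": ["organization", "person", "school of thought"]}
entity_types_json = raw_json.get("entity_types", [])

assert entity_types == entity_types_json
```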
+
+ENTITY_TYPE_GENERATION_JSON_PROMPT = """
+The goal is to study the connections and relations between the entity types and their features in order to understand all available information from the text.
+The user's task is to {task}.
+As part of the analysis, you want to identify the entity types present in the following text.
+The entity types must be relevant to the user task.
+Avoid general entity types such as "other" or "unknown".
+This is VERY IMPORTANT: Do not generate redundant or overlapping entity types. For example, if the text contains "company" and "organization" entity types, you should return only one of them.
+Don't worry about quantity, always choose quality over quantity. And make sure EVERYTHING in your answer is relevant to the context of entity extraction.
+Return the entity types in JSON format with "entity_types" as the key and the entity types as an array of strings.
+=====================================================================
+EXAMPLE SECTION: The following section includes example output. These examples **must be excluded from your answer**.
+
+EXAMPLE 1
+Task: Determine the connections and organizational hierarchy within the specified community.
+Text: Example_Org_A is a company in Sweden. Example_Org_A's director is Example_Individual_B.
+JSON RESPONSE:
+{{"entity_types": [organization, person] }}
+END OF EXAMPLE 1
+
+EXAMPLE 2
+Task: Identify the key concepts, principles, and arguments shared among different philosophical schools of thought, and trace the historical or ideological influences they have on each other.
+Text: Rationalism, epitomized by thinkers such as René Descartes, holds that reason is the primary source of knowledge. Key concepts within this school include the emphasis on the deductive method of reasoning.
+JSON RESPONSE:
+{{"entity_types": [concept, person, school of thought] }}
+END OF EXAMPLE 2
+
+EXAMPLE 3
+Task: Identify the full range of basic forces, factors, and trends that would indirectly shape an issue.
+Text: Industry leaders such as Panasonic are vying for supremacy in the battery production sector. They are investing heavily in research and development and are exploring new technologies to gain a competitive edge.
+JSON RESPONSE:
+{{"entity_types": [organization, technology, sectors, investment strategies] }}
+END OF EXAMPLE 3
+======================================================================
+
+======================================================================
+REAL DATA: The following section is the real data. You should use only this real data to prepare your answer. Generate Entity Types only.
+Task: {task}
+Text: {input_text}
+JSON response:
+{{"entity_types": [<entity_types>] }}
+"""
diff --git a/func-app/graphrag/prompt_tune/prompt/language.py b/func-app/graphrag/prompt_tune/prompt/language.py
new file mode 100644
index 0000000000..68fd04029f
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/prompt/language.py
@@ -0,0 +1,12 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Fine-tuning prompts for language detection."""
+
+DETECT_LANGUAGE_PROMPT = """
+You are an intelligent assistant that helps a human to analyze the information in a text document.
+Given a sample text, help the user by determining what's the primary language of the provided texts.
+Examples are: "English", "Spanish", "Japanese", "Portuguese" among others.
+
+Text: {input_text}
+Language:"""
diff --git a/func-app/graphrag/prompt_tune/prompt/persona.py b/func-app/graphrag/prompt_tune/prompt/persona.py
new file mode 100644
index 0000000000..58515fd204
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/prompt/persona.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Fine-tuning prompts for persona generation."""
+
+GENERATE_PERSONA_PROMPT = """
+You are an intelligent assistant that helps a human to analyze the information in a text document.
+Given a specific type of task and sample text, help the user by generating a 3 to 4 sentence description of an expert who could help solve the problem.
+Use a format similar to the following:
+You are an expert {{role}}. You are skilled at {{relevant skills}}. You are adept at helping people with {{specific task}}.
+
+task: {sample_task}
+persona description:"""
diff --git a/func-app/graphrag/prompt_tune/template/__init__.py b/func-app/graphrag/prompt_tune/template/__init__.py
new file mode 100644
index 0000000000..e056762ff7
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/template/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Fine-tuning prompts for entity extraction, entity summarization, and community report summarization."""
+
+from .community_report_summarization import COMMUNITY_REPORT_SUMMARIZATION_PROMPT
+from .entity_extraction import (
+    EXAMPLE_EXTRACTION_TEMPLATE,
+    GRAPH_EXTRACTION_JSON_PROMPT,
+    GRAPH_EXTRACTION_PROMPT,
+    UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE,
+    UNTYPED_GRAPH_EXTRACTION_PROMPT,
+)
+from .entity_summarization import ENTITY_SUMMARIZATION_PROMPT
+
+__all__ = [
+    "COMMUNITY_REPORT_SUMMARIZATION_PROMPT",
+    "ENTITY_SUMMARIZATION_PROMPT",
+    "EXAMPLE_EXTRACTION_TEMPLATE",
+    "GRAPH_EXTRACTION_JSON_PROMPT",
+    "GRAPH_EXTRACTION_PROMPT",
+    "UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE",
+    "UNTYPED_GRAPH_EXTRACTION_PROMPT",
+]
diff --git a/func-app/graphrag/prompt_tune/template/community_report_summarization.py b/func-app/graphrag/prompt_tune/template/community_report_summarization.py
new file mode 100644
index 0000000000..14e039ba41
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/template/community_report_summarization.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Fine-tuning prompts for community report summarization."""
+
+COMMUNITY_REPORT_SUMMARIZATION_PROMPT = """
+{persona}
+
+# Goal
+Write a comprehensive assessment report of a community taking on the role of a {role}. The content of this report includes an overview of the community's key entities and relationships.
+
+# Report Structure
+The report should include the following sections:
+- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
+- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant points associated with its entities.
+- REPORT RATING: {report_rating_description}
+- RATING EXPLANATION: Give a single sentence explanation of the rating.
+- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
+
+Return output as a well-formed JSON-formatted string with the following format. Don't use any unnecessary escape sequences. The output should be a single JSON object that can be parsed by json.loads.
+    {{
+        "title": <report_title>,
+        "summary": <executive_summary>,
+        "rating": <report_rating>,
+        "rating_explanation": <rating_explanation>,
+        "findings": "[{{"summary": <insight_summary>, "explanation": <insight_explanation>}}, {{"summary": <insight_summary>, "explanation": <insight_explanation>}}]"
+    }}
+
+# Grounding Rules
+After each paragraph, add a data record reference if the content of the paragraph was derived from one or more data records. Reference is in the format [records: <record_source> (<record_id_list>), ...<record_source> (<record_id_list>)]. If there are more than 10 data records, show the top 10 most relevant records.
+Each paragraph should contain multiple sentences of explanation and concrete examples with specific named entities. All paragraphs must have these references at the start and end.
Use "NONE" if there are no related roles or records. Everything should be in {language}. + +Example paragraph with references added: +This is a paragraph of the output text [records: Entities (1, 2, 3), Claims (2, 5), Relationships (10, 12)] + +# Example Input +----------- +Text: + +Entities + +id,entity,description +5,ABILA CITY PARK,Abila City Park is the location of the POK rally + +Relationships + +id,source,target,description +37,ABILA CITY PARK,POK RALLY,Abila City Park is the location of the POK rally +38,ABILA CITY PARK,POK,POK is holding a rally in Abila City Park +39,ABILA CITY PARK,POKRALLY,The POKRally is taking place at Abila City Park +40,ABILA CITY PARK,CENTRAL BULLETIN,Central Bulletin is reporting on the POK rally taking place in Abila City Park + +Output: +{{ + "title": "Abila City Park and POK Rally", + "summary": "The community revolves around the Abila City Park, which is the location of the POK rally. The park has relationships with POK, POKRALLY, and Central Bulletin, all +of which are associated with the rally event.", + "rating": 5.0, + "rating_explanation": "The impact rating is moderate due to the potential for unrest or conflict during the POK rally.", + "findings": [ + {{ + "summary": "Abila City Park as the central location", + "explanation": "Abila City Park is the central entity in this community, serving as the location for the POK rally. This park is the common link between all other +entities, suggesting its significance in the community. The park's association with the rally could potentially lead to issues such as public disorder or conflict, depending on the +nature of the rally and the reactions it provokes. [records: Entities (5), Relationships (37, 38, 39, 40)]" + }}, + {{ + "summary": "POK's role in the community", + "explanation": "POK is another key entity in this community, being the organizer of the rally at Abila City Park. The nature of POK and its rally could be a potential +source of threat, depending on their objectives and the reactions they provoke. The relationship between POK and the park is crucial in understanding the dynamics of this community. +[records: Relationships (38)]" + }}, + {{ + "summary": "POKRALLY as a significant event", + "explanation": "The POKRALLY is a significant event taking place at Abila City Park. This event is a key factor in the community's dynamics and could be a potential +source of threat, depending on the nature of the rally and the reactions it provokes. The relationship between the rally and the park is crucial in understanding the dynamics of this +community. [records: Relationships (39)]" + }}, + {{ + "summary": "Role of Central Bulletin", + "explanation": "Central Bulletin is reporting on the POK rally taking place in Abila City Park. This suggests that the event has attracted media attention, which could +amplify its impact on the community. The role of Central Bulletin could be significant in shaping public perception of the event and the entities involved. [records: Relationships +(40)]" + }} + ] + +}} + +# Real Data + +Use the following text for your answer. Do not make anything up in your answer. + +Text: +{{input_text}} +Output:""" diff --git a/func-app/graphrag/prompt_tune/template/entity_extraction.py b/func-app/graphrag/prompt_tune/template/entity_extraction.py new file mode 100644 index 0000000000..32d8756ec2 --- /dev/null +++ b/func-app/graphrag/prompt_tune/template/entity_extraction.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Fine-tuning prompts for entity extraction."""
+
+GRAPH_EXTRACTION_PROMPT = """
+-Goal-
+Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
+
+-Steps-
+1. Identify all entities. For each identified entity, extract the following information:
+- entity_name: Name of the entity, capitalized
+- entity_type: One of the following types: [{entity_types}]
+- entity_description: Comprehensive description of the entity's attributes and activities
+Format each entity as ("entity"{{tuple_delimiter}}<entity_name>{{tuple_delimiter}}<entity_type>{{tuple_delimiter}}<entity_description>)
+
+2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: name of the source entity, as identified in step 1
+- target_entity: name of the target entity, as identified in step 1
+- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
+- relationship_strength: an integer score between 1 and 10, indicating strength of the relationship between the source entity and target entity
+Format each relationship as ("relationship"{{tuple_delimiter}}<source_entity>{{tuple_delimiter}}<target_entity>{{tuple_delimiter}}<relationship_description>{{tuple_delimiter}}<relationship_strength>)
+
+3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{{record_delimiter}}** as the list delimiter.
+
+4. If you have to translate into {language}, just translate the descriptions, nothing else!
+
+5. When finished, output {{completion_delimiter}}.
+
+-Examples-
+######################
+{examples}
+
+-Real Data-
+######################
+entity_types: [{entity_types}]
+text: {{input_text}}
+######################
+output:"""
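+
+# Illustrative output records, assuming the default delimiters (tuple_delimiter
+# rendered as "<|>" and record_delimiter as "##"); the entity values are
+# hypothetical and only show the expected shape:
+#   ("entity"<|>ABILA CITY PARK<|>GEO<|>Abila City Park is the location of the POK rally)##
+#   ("relationship"<|>ABILA CITY PARK<|>POK RALLY<|>Abila City Park is the location of the POK rally<|>7)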
+
+GRAPH_EXTRACTION_JSON_PROMPT = """
+-Goal-
+Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
+
+-Steps-
+1. Identify all entities. For each identified entity, extract the following information:
+- entity_name: Name of the entity, capitalized
+- entity_type: One of the following types: [{entity_types}]
+- entity_description: Comprehensive description of the entity's attributes and activities
+Format each entity output as a JSON entry with the following format:
+
+{{"name": <entity_name>, "type": <entity_type>, "description": <entity_description>}}
+
+2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: name of the source entity, as identified in step 1
+- target_entity: name of the target entity, as identified in step 1
+- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
+- relationship_strength: an integer score between 1 and 10, indicating strength of the relationship between the source entity and target entity
+Format each relationship as a JSON entry with the following format:
+
+{{"source": <source_entity>, "target": <target_entity>, "relationship": <relationship_description>, "relationship_strength": <relationship_strength>}}
+
+3. Return output in {language} as a single list of all JSON entities and relationships identified in steps 1 and 2.
+
+4. If you have to translate into {language}, just translate the descriptions, nothing else!
+
+-Examples-
+######################
+{examples}
+
+-Real Data-
+######################
+entity_types: {entity_types}
+text: {{input_text}}
+######################
+output:"""
+
+EXAMPLE_EXTRACTION_TEMPLATE = """
+Example {n}:
+
+entity_types: [{entity_types}]
+text:
+{input_text}
+------------------------
+output:
+{output}
+#############################
+
+"""
+
+UNTYPED_EXAMPLE_EXTRACTION_TEMPLATE = """
+Example {n}:
+
+text:
+{input_text}
+------------------------
+output:
+{output}
+#############################
+
+"""
+
+
+UNTYPED_GRAPH_EXTRACTION_PROMPT = """
+-Goal-
+Given a text document that is potentially relevant to this activity, first identify all entities needed from the text in order to capture the information and ideas in the text.
+Next, report all relationships among the identified entities.
+
+-Steps-
+1. Identify all entities. For each identified entity, extract the following information:
+- entity_name: Name of the entity, capitalized
+- entity_type: Suggest several labels or categories for the entity. The categories should not be specific, but should be as general as possible.
+- entity_description: Comprehensive description of the entity's attributes and activities
+Format each entity as ("entity"{{tuple_delimiter}}<entity_name>{{tuple_delimiter}}<entity_type>{{tuple_delimiter}}<entity_description>)
+
+2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: name of the source entity, as identified in step 1
+- target_entity: name of the target entity, as identified in step 1
+- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
+- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
+Format each relationship as ("relationship"{{tuple_delimiter}}<source_entity>{{tuple_delimiter}}<target_entity>{{tuple_delimiter}}<relationship_description>{{tuple_delimiter}}<relationship_strength>)
+
+3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{{record_delimiter}}** as the list delimiter.
+
+4. If you have to translate into {language}, just translate the descriptions, nothing else!
+
+5. When finished, output {{completion_delimiter}}.
+
+-Examples-
+######################
+{examples}
+
+-Real Data-
+######################
+text: {{input_text}}
+######################
+output:
+"""
diff --git a/func-app/graphrag/prompt_tune/template/entity_summarization.py b/func-app/graphrag/prompt_tune/template/entity_summarization.py
new file mode 100644
index 0000000000..60294a291b
--- /dev/null
+++ b/func-app/graphrag/prompt_tune/template/entity_summarization.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Fine-tuning prompts for entity summarization."""
+
+ENTITY_SUMMARIZATION_PROMPT = """
+{persona}
+Using your expertise, you're asked to generate a comprehensive summary of the data provided below.
+Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
+Please concatenate all of these into a single, concise description in {language}. Make sure to include information collected from all the descriptions.
+If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
+Make sure it is written in third person, and include the entity names so we have the full context.
+
+Enrich it as much as you can with relevant information from the nearby text, this is very important.
+
+If no answer is possible, or the description is empty, only convey information that is provided within the text.
+#######
+-Data-
+Entities: {{entity_name}}
+Description List: {{description_list}}
+#######
+Output:"""
diff --git a/func-app/graphrag/query/__init__.py b/func-app/graphrag/query/__init__.py
new file mode 100644
index 0000000000..58a557f8a2
--- /dev/null
+++ b/func-app/graphrag/query/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""GraphRAG Orchestration Module."""
diff --git a/func-app/graphrag/query/__main__.py b/func-app/graphrag/query/__main__.py
new file mode 100644
index 0000000000..e2e01d6fa6
--- /dev/null
+++ b/func-app/graphrag/query/__main__.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Query Engine package root."""
+
+import argparse
+from enum import Enum
+
+from .cli import run_global_search, run_local_search
+
+INVALID_METHOD_ERROR = "Invalid method"
+
+
+class SearchType(Enum):
+    """The type of search to run."""
+
+    LOCAL = "local"
+    GLOBAL = "global"
+
+    def __str__(self):
+        """Return the string representation of the enum value."""
+        return self.value
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--config",
+        help="The configuration yaml file to use when running the query",
+        required=False,
+        type=str,
+    )
+
+    parser.add_argument(
+        "--data",
+        help="The path with the output data from the pipeline",
+        required=False,
+        type=str,
+    )
+
+    parser.add_argument(
+        "--root",
+        help="The data project root. Default value: the current directory",
+        required=False,
+        default=".",
+        type=str,
+    )
+
+    parser.add_argument(
+        "--method",
+        help="The method to run, one of: local or global",
+        required=True,
+        type=SearchType,
+        choices=list(SearchType),
+    )
+
+    parser.add_argument(
+        "--community_level",
+        help="Community level in the Leiden community hierarchy from which we will load the community reports. A higher value means we use reports on smaller communities.",
+        type=int,
+        default=2,
+    )
+
+    parser.add_argument(
+        "--response_type",
+        help="Free form text describing the response type and format, can be anything, e.g. 
Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report",
+        type=str,
+        default="Multiple Paragraphs",
+    )
+
+    parser.add_argument(
+        "--context_id",
+        help="Guid describing context in which the search should be performed",
+        type=str,
+        #default="00000000-0000-0000-0000-000000000000",
+    )
+
+    parser.add_argument(
+        "--optimized_search",
+        help="Runs optimized search and exports artifacts",
+        action="store_true",
+    )
+
+    parser.add_argument(
+        "--use_kusto_community_reports",
+        help="If enabled, community reports stored in Kusto are used during the query",
+        action="store_true",
+    )
+
+    parser.add_argument(
+        "--paths",
+        help="Different paths for the query",
+        type=int,
+        default=0,  # Default to normal graphrag search
+    )
+
+    parser.add_argument(
+        "query",
+        nargs=1,
+        help="The query to run",
+        type=str,
+    )
+
+    args = parser.parse_args()
+
+    match args.method:
+        case SearchType.LOCAL:
+            run_local_search(
+                args.config,
+                args.data,
+                args.root,
+                args.community_level,
+                args.response_type,
+                args.context_id,
+                args.query[0],
+                optimized_search=args.optimized_search,
+                use_kusto_community_reports=args.use_kusto_community_reports,
+                paths=args.paths,
+            )
+        case SearchType.GLOBAL:
+            run_global_search(
+                args.config,
+                args.data,
+                args.root,
+                args.community_level,
+                args.response_type,
+                args.context_id,
+                args.query[0],
+            )
+        case _:
+            raise ValueError(INVALID_METHOD_ERROR)
diff --git a/func-app/graphrag/query/cli.py b/func-app/graphrag/query/cli.py
new file mode 100644
index 0000000000..a8a06d69cd
--- /dev/null
+++ b/func-app/graphrag/query/cli.py
@@ -0,0 +1,472 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Command line interface for the query module."""
+
+import asyncio
+import os
+from pathlib import Path
+from typing import cast
+from io import BytesIO
+
+from datashaper import VerbCallbacks
+from graphrag.common.progress.rich import RichProgressReporter
+from graphrag.common.storage import PipelineStorage, BlobPipelineStorage, FilePipelineStorage
+from graphrag.common.utils.context_utils import get_files_by_contextid
+from graphrag.config.enums import StorageType
+from azure.core.exceptions import ResourceNotFoundError
+
+import pandas as pd
+
+from graphrag.config import (
+    create_graphrag_config,
+    GraphRagConfig,
+)
+from graphrag.common.progress import PrintProgressReporter
+from graphrag.index.verbs.entities.extraction.strategies.graph_intelligence.run_graph_intelligence import run_gi
+from graphrag.index.verbs.entities.extraction.strategies.typing import Document
+from graphrag.model.entity import Entity
+from graphrag.query.input.loaders.dfs import (
+    store_entity_semantic_embeddings,
+)
+from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
+from graphrag.vector_stores.base import BaseVectorStore
+from graphrag.vector_stores.lancedb import LanceDBVectorStore
+from graphrag.vector_stores.kusto import KustoVectorStore
+from .factories import get_global_search_engine, get_local_search_engine
+from .indexer_adapters import (
+    read_indexer_covariates,
+    read_indexer_entities,
+    read_indexer_relationships,
+    read_indexer_reports,
+    kt_read_indexer_reports,
+    read_indexer_text_units,
+)
+
+from common.graph_db_client import GraphDBClient
+
+reporter = PrintProgressReporter("")
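+
+# Collection names are derived per context before connecting: with a context_id
+# of "ctx1" and the default collection name, the helper below would target
+# "entity_description_embeddings_ctx1", plus "reports_ctx1" and "text_units_ctx1"
+# for the Kusto-backed path. Illustrative call, with hypothetical arguments:
+#
+#   store = __get_embedding_description_store(
+#       vector_store_type=VectorStoreType.Kusto,
+#       config_args={"query_collection_name": "entity_description_embeddings"},
+#       context_id="ctx1",
+#   )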
+
+def __get_embedding_description_store(
+    entities: list[Entity] = [],
+    vector_store_type: str = VectorStoreType.LanceDB,
+    config_args: dict | None = None,
+    context_id: str = "",
+):
+    """Get the embedding description store."""
+    if not config_args:
+        config_args = {}
+
+    collection_name = config_args.get(
+        "query_collection_name", "entity_description_embeddings"
+    )
+    config_args.update({"collection_name": f"{collection_name}_{context_id}" if context_id else collection_name})
+    vector_name = config_args.get(
+        "vector_search_column", "description_embedding"
+    )
+    config_args.update({"vector_name": vector_name})
+    config_args.update({"reports_name": f"reports_{context_id}" if context_id else "reports"})
+    config_args.update({"text_units_name": f"text_units_{context_id}"})
+
+    description_embedding_store = VectorStoreFactory.get_vector_store(
+        vector_store_type=vector_store_type, kwargs=config_args
+    )
+
+    description_embedding_store.connect(**config_args)
+
+    if vector_store_type == VectorStoreType.Kusto:
+        return description_embedding_store
+
+    elif config_args.get("overwrite", True):
+        # this step assumes the embeddings were originally stored in a file rather
+        # than a vector database
+
+        # dump embeddings from the entities list to the description_embedding_store
+        store_entity_semantic_embeddings(
+            entities=entities, vectorstore=description_embedding_store
+        )
+    else:
+        # load description embeddings to an in-memory lancedb vectorstore
+        # to connect to a remote db, specify url and port values.
+        description_embedding_store = LanceDBVectorStore(
+            collection_name=collection_name
+        )
+        description_embedding_store.connect(
+            db_uri=config_args.get("db_uri", "./lancedb")
+        )
+
+        # load data from an existing table
+        description_embedding_store.document_collection = (
+            description_embedding_store.db_connection.open_table(
+                description_embedding_store.collection_name
+            )
+        )
+
+    return description_embedding_store
+
+
+def run_global_search(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+):
+    """Run a global search with the given query."""
+    data_dir, root_dir, config = _configure_paths_and_settings(
+        data_dir, root_dir, config_dir
+    )
+    if config.graphdb.enabled:
+        graph_db_client = GraphDBClient(config.graphdb)
+    data_path = Path(data_dir)
+
+    final_nodes: pd.DataFrame = pd.read_parquet(
+        data_path / "create_final_nodes.parquet"
+    )
+    if config.graphdb.enabled:
+        final_entities = graph_db_client.query_vertices()
+    else:
+        final_entities: pd.DataFrame = pd.read_parquet(
+            data_path / "create_final_entities.parquet"
+        )
+    final_community_reports: pd.DataFrame = pd.read_parquet(
+        data_path / "create_final_community_reports.parquet"
+    )
+
+    reports = read_indexer_reports(
+        final_community_reports, final_nodes, community_level
+    )
+    entities = read_indexer_entities(final_nodes, final_entities, community_level)
+    search_engine = get_global_search_engine(
+        config,
+        reports=reports,
+        entities=entities,
+        response_type=response_type,
+    )
+
+    result = search_engine.search(query=query)
+
+    reporter.success(f"Global Search Response: {result.response}")
+    return result.response
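+
+# run_local_search below fans out to the numbered path* helpers: path0 is the
+# default local search, path2 runs entity extraction over the query text itself,
+# and path1/path3 are placeholders. A hypothetical invocation, assuming a
+# project rooted at ./ragtest with default settings:
+#
+#   run_local_search(None, None, "./ragtest", 2, "Multiple Paragraphs",
+#                    context_id="", query="Who is Tim Belden?", paths=0)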
+
+def path0(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+    optimized_search: bool = False,
+    use_kusto_community_reports: bool = False,
+    ):
+    """Run a local search with the given query."""
+    data_dir, root_dir, config = _configure_paths_and_settings(
+        data_dir, root_dir, config_dir
+    )
+
+    vector_store_args = (
+        config.embeddings.vector_store if config.embeddings.vector_store else {}
+    )
+
+    reporter.info(f"Vector Store Args: {vector_store_args}")
+    vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)
+
+    entities=[]
+    text_units=[]
+    covariates=[]
+    reports=[]
+    final_relationships=[]
+
+    if(config.storage.type == StorageType.blob):
+        if(config.storage.container_name is not None):
+            output_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string,
+                                container_name=config.storage.container_name,
+                                storage_account_blob_url=config.storage.storage_account_blob_url)
+        else:
+            raise ValueError("Storage type is Blob but the container name is invalid")
+    elif(config.storage.type == StorageType.file):
+        output_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir)
+
+    ##### LEGACY #######################
+
+    if vector_store_type == VectorStoreType.LanceDB:
+        # for the POC, the input artifacts blob, output artifacts blob, and input query blob storage are the same.
+        if(config.storage.type == StorageType.memory):
+            raise ValueError("Memory storage is not supported")
+        if(config.storage.type == StorageType.blob):
+            if(config.storage.container_name is not None):
+                input_storage_client: PipelineStorage = BlobPipelineStorage(connection_string=config.storage.connection_string,
+                                    container_name=config.storage.container_name,
+                                    storage_account_blob_url=config.storage.storage_account_blob_url)
+            else:
+                raise ValueError("Storage type is Blob but the container name is invalid")
+        if(config.storage.type == StorageType.file):
+            input_storage_client: PipelineStorage = FilePipelineStorage(config.root_dir)
+
+    data_paths = get_files_by_contextid(config, context_id)
+    final_nodes = pd.DataFrame()
+    final_community_reports = pd.DataFrame()
+    final_text_units = pd.DataFrame()
+    final_relationships = pd.DataFrame()
+    final_entities = pd.DataFrame()
+    final_covariates = pd.DataFrame()
+
+    for data_path in data_paths:
+        # check the config for the output storage type, then read the data from that storage.
+
+        # GraphDB: we may need to change the line below to read nodes data from Graph DB
+        final_nodes = pd.concat([final_nodes, read_paraquet_file(input_storage_client, data_path + "/create_final_nodes.parquet")])
+        final_community_reports = pd.concat([final_community_reports, read_paraquet_file(input_storage_client, data_path + "/create_final_community_reports.parquet")]) # KustoDB: final_entities, final_nodes and final_reports should be merged and inserted into Kusto
+        final_text_units = pd.concat([final_text_units, read_paraquet_file(input_storage_client, data_path + "/create_final_text_units.parquet")]) # LanceDB search needs it for embedding mapping; we have embeddings in entities and should use them from there. KustoDB must already have sorted it.
+        final_relationships = pd.concat([final_relationships, read_paraquet_file(input_storage_client, data_path + "/create_final_relationships.parquet")])
+
+        if not optimized_search:
+            final_covariates = pd.concat([final_covariates, read_paraquet_file(input_storage_client, data_path + "/create_final_covariates.parquet")])
+
+        final_entities = pd.concat([final_entities, read_paraquet_file(input_storage_client, data_path + "/create_final_entities.parquet")])
+
+    ############# End of for loop
+
+    entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read final nodes data and entities data and merge them.
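+    # read_paraquet_file returns an empty DataFrame for a missing artifact, so the
+    # pd.concat calls above tolerate context paths that carry only a subset of the
+    # parquet outputs; e.g. a context with no covariates simply contributes an
+    # empty final_covariates frame (hypothetical layout, for illustration only).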
+    reports=read_indexer_reports(
+        final_community_reports, final_nodes, community_level
+    )
+
+    final_relationships=read_indexer_relationships(final_relationships)
+
+    covariates = (
+        read_indexer_covariates(final_covariates)
+        if not final_covariates.empty
+        else []
+    )
+    text_units=read_indexer_text_units(final_text_units)
+
+    ########################################################################################
+
+    if use_kusto_community_reports:
+        raise ValueError("Using community reports is not supported.")
+
+    description_embedding_store = __get_embedding_description_store(
+        entities=entities,
+        vector_store_type=vector_store_type,
+        config_args=vector_store_args,
+        context_id=context_id,
+    )
+
+    # If Kusto is enabled, both entities and final_relationships must be empty.
+    search_engine = get_local_search_engine(
+        config,
+        reports=reports,
+        text_units=text_units,
+        entities=entities,
+        relationships=final_relationships,
+        covariates={"claims": covariates},
+        description_embedding_store=description_embedding_store,
+        response_type=response_type,
+        context_id=context_id,
+        is_optimized_search=optimized_search,
+        use_kusto_community_reports=use_kusto_community_reports,
+        graphdb_config=config.graphdb,
+    )
+
+    if optimized_search:
+        result = search_engine.optimized_search(query=query)
+    else:
+        result = search_engine.search(query=query)
+    for key in result.context_data.keys():
+        # to_parquet() without a path returns bytes; some editors flag this call, but it is valid.
+        asyncio.run(output_storage_client.set("query/output/" + key + ".parquet", result.context_data[key].to_parquet()))
+    reporter.success(f"Local Search Response: {result.response}")
+    return result.response
+
+def path1(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+    optimized_search: bool = False,
+    use_kusto_community_reports: bool = False,
+    ):
+    raise NotImplementedError("path1 is not implemented")
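+
+# path2 below reuses the graph-intelligence extraction strategy from index time,
+# but runs it over the query text itself. run_gi expects the LLM settings nested
+# under an "llm" key, so the dict built in path2 is equivalent to (hypothetical
+# values, for illustration only):
+#
+#   llmm = {"llm": {"api_key": "...", "type": config.llm.type,
+#                   "model": config.llm.model, "deployment_name": "..."}}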
+
+def path2(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+    optimized_search: bool = False,
+    use_kusto_community_reports: bool = False,
+    ):
+    """Path 2
+    Find all the emails sent to trader by Tim Belden
+    a. Query -> LLM -> Entity Extracted -> 5 entities -> Set A [TimBelden1]
+    b. Query -> LLM -> Embeddings -> Y [x1..... Xn]
+    c. Run the query on Kusto for embedding Y [x1.....xn] for entityId in [TimBelden1]
+    d. Get the text units and get the response"""
+    data_dir, root_dir, config = _configure_paths_and_settings(
+        data_dir, root_dir, config_dir
+    )
+
+    # Populate args with a dict of arguments for the LLM
+    args = {}
+    args['api_key'] = config.llm.api_key
+    args['type'] = config.llm.type
+    args['model'] = config.llm.model
+    args['model_supports_json'] = config.llm.model_supports_json
+    args['api_base'] = config.llm.api_base
+    args['api_version'] = config.llm.api_version
+    args['deployment_name'] = config.llm.deployment_name
+    llmm = {}
+    llmm['llm'] = args
+
+    result = asyncio.run(run_gi(
+        docs=[Document(text=query, id='0')],
+        entity_types=config.entity_extraction.entity_types,
+        reporter = None,
+        pipeline_cache=None,
+        args=llmm,
+    ))
+
+    print(result.entities)
+    exit(0)
+
+def path3(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+    optimized_search: bool = False,
+    use_kusto_community_reports: bool = False,
+    ):
+    raise NotImplementedError("path3 is not implemented")
+
+
+def run_local_search(
+    config_dir: str | None,
+    data_dir: str | None,
+    root_dir: str | None,
+    community_level: int,
+    response_type: str,
+    context_id: str,
+    query: str,
+    optimized_search: bool = False,
+    use_kusto_community_reports: bool = False,
+    paths: int = 0,):
+    """Run a local search with the given query."""
+    if(paths==1):
+        return path1(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports)
+    elif(paths==2):
+        return path2(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports)
+    elif(paths==3):
+        return path3(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports)
+    return path0(config_dir, data_dir, root_dir, community_level, response_type, context_id, query, optimized_search, use_kusto_community_reports)
+
+def blob_exists(container_client, blob_name):
+    blob_client = container_client.get_blob_client(blob_name)
+    try:
+        # Attempt to get the blob properties
+        blob_client.get_blob_properties()
+        return True
+    except ResourceNotFoundError:
+        # Blob does not exist
+        return False
+
+
+def read_paraquet_file(storage: PipelineStorage, path: str):
+    # TODO: add a dedicated enum for the parquet storage type
+    file_data = asyncio.run(storage.get(path, True))
+    if file_data is None:
+        return pd.DataFrame()
+    return pd.read_parquet(BytesIO(file_data), engine="pyarrow")
+
+def _configure_paths_and_settings(
+    data_dir: str | None,
+    root_dir: str | None,
+    config_dir: str | None,
+) -> tuple[str, str | None, GraphRagConfig]:
+    if data_dir is None and root_dir is None:
+        msg = "Either data_dir or root_dir must be provided."
+ raise ValueError(msg) + if data_dir is None: + data_dir = _infer_data_dir(cast(str, root_dir)) + config = _create_graphrag_config(root_dir, config_dir) + return data_dir, root_dir, config + + +def _infer_data_dir(root: str) -> str: + output = Path(root) / "output" + # use the latest data-run folder + if output.exists(): + folders = sorted(output.iterdir(), key=os.path.getmtime, reverse=True) + if len(folders) > 0: + folder = folders[0] + return str((folder / "artifacts").absolute()) + msg = f"Could not infer data directory from root={root}" + raise ValueError(msg) + + +def _create_graphrag_config( + root: str | None, + config_dir: str | None, +) -> GraphRagConfig: + """Create a GraphRag configuration.""" + return _read_config_parameters(root or "./", config_dir) + + +def _read_config_parameters(root: str, config: str | None): + _root = Path(root) + settings_yaml = ( + Path(config) + if config and Path(config).suffix in [".yaml", ".yml"] + else _root / "settings.yaml" + ) + if not settings_yaml.exists(): + settings_yaml = _root / "settings.yml" + + if settings_yaml.exists(): + reporter.info(f"Reading settings from {settings_yaml}") + with settings_yaml.open( + "rb", + ) as file: + import yaml + + data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + settings_json = ( + Path(config) + if config and Path(config).suffix == ".json" + else _root / "settings.json" + ) + if settings_json.exists(): + reporter.info(f"Reading settings from {settings_json}") + with settings_json.open("rb") as file: + import json + + data = json.loads(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + + reporter.info("Reading settings from environment variables") + return create_graphrag_config(root_dir=root) diff --git a/func-app/graphrag/query/context_builder/__init__.py b/func-app/graphrag/query/context_builder/__init__.py new file mode 100644 index 0000000000..7e27364e1e --- /dev/null +++ b/func-app/graphrag/query/context_builder/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Functions to build context for system prompt to generate responses for a user query.""" diff --git a/func-app/graphrag/query/context_builder/builders.py b/func-app/graphrag/query/context_builder/builders.py new file mode 100644 index 0000000000..7a4ba277ae --- /dev/null +++ b/func-app/graphrag/query/context_builder/builders.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Base classes for global and local context builders.""" + +from abc import ABC, abstractmethod + +import pandas as pd + +from graphrag.query.context_builder.conversation_history import ( + ConversationHistory, +) + + +class GlobalContextBuilder(ABC): + """Base class for global-search context builders.""" + + @abstractmethod + def build_context( + self, conversation_history: ConversationHistory | None = None, **kwargs + ) -> tuple[str | list[str], dict[str, pd.DataFrame]]: + """Build the context for the global search mode.""" + + +class LocalContextBuilder(ABC): + """Base class for local-search context builders.""" + + @abstractmethod + def build_context( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> tuple[str | list[str], dict[str, pd.DataFrame]]: + """Build the context for the local search mode.""" diff --git a/func-app/graphrag/query/context_builder/community_context.py b/func-app/graphrag/query/context_builder/community_context.py new file mode 100644 index 0000000000..ad345a2704 --- /dev/null +++ b/func-app/graphrag/query/context_builder/community_context.py @@ -0,0 +1,253 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Community Context.""" + +import logging +import random +from typing import Any, cast + +import pandas as pd +import tiktoken + +from graphrag.model import CommunityReport, Entity +from graphrag.query.llm.text_utils import num_tokens + +log = logging.getLogger(__name__) + + +def build_community_context( + community_reports: list[CommunityReport], + entities: list[Entity] | None = None, + token_encoder: tiktoken.Encoding | None = None, + use_community_summary: bool = True, + column_delimiter: str = "|", + shuffle_data: bool = True, + include_community_rank: bool = False, + min_community_rank: int = 0, + community_rank_name: str = "rank", + include_community_weight: bool = True, + community_weight_name: str = "occurrence weight", + normalize_community_weight: bool = True, + max_tokens: int = 8000, + single_batch: bool = True, + context_name: str = "Reports", + random_state: int = 86, + is_optimized_search: bool = False, +) -> tuple[str | list[str], dict[str, pd.DataFrame]]: + """ + Prepare community report data table as context data for system prompt. + + If entities are provided, the community weight is calculated as the count of text units associated with entities within the community. + + The calculated weight is added as an attribute to the community reports and added to the context data table. 
+ """ + + def _is_included(report: CommunityReport) -> bool: + return report.rank is not None and report.rank >= min_community_rank + + def _get_header(attributes: list[str]) -> list[str]: + header = ["id", "title"] + attributes = [col for col in attributes if col not in header] + if not include_community_weight: + attributes = [col for col in attributes if col != community_weight_name] + header.extend(attributes) + header.append("summary" if use_community_summary else "content") + if include_community_rank: + header.append(community_rank_name) + return header + + def _report_context_text( + report: CommunityReport, attributes: list[str], + is_optimized_search: bool = False + ) -> tuple[str, list[str]]: + context: list[str] = [ + report.short_id if report.short_id else "", + report.title, + *[ + str(report.attributes.get(field, "")) if report.attributes else "" + for field in attributes + ], + ] + context.append(report.summary if use_community_summary else report.full_content) + if include_community_rank: + context.append(str(report.rank)) + result = column_delimiter.join(context) + "\n" + return result, context + + compute_community_weights = ( + entities + and len(community_reports) > 0 + and include_community_weight + and ( + community_reports[0].attributes is None + or community_weight_name not in community_reports[0].attributes + ) + ) + if compute_community_weights: + log.info("Computing community weights...") + community_reports = _compute_community_weights( + community_reports=community_reports, + entities=entities, + weight_attribute=community_weight_name, + normalize=normalize_community_weight, + ) + + selected_reports = [report for report in community_reports if _is_included(report)] + + if selected_reports is None or len(selected_reports) == 0: + return ([], {}) + + if shuffle_data: + random.seed(random_state) + random.shuffle(selected_reports) + + # "global" variables + attributes = ( + list(community_reports[0].attributes.keys()) + if community_reports[0].attributes + else [] + ) + header = _get_header(attributes) + all_context_text: list[str] = [] + all_context_records: list[pd.DataFrame] = [] + + # batch variables + batch_text: str = "" + batch_tokens: int = 0 + batch_records: list[list[str]] = [] + + def _init_batch() -> None: + nonlocal batch_text, batch_tokens, batch_records + batch_text = ( + f"-----{context_name}-----" + "\n" + column_delimiter.join(header) + "\n" + ) + batch_tokens = num_tokens(batch_text, token_encoder) + batch_records = [] + + def _cut_batch() -> None: + # convert the current context records to pandas dataframe and sort by weight and rank if exist + record_df = _convert_report_context_to_df( + context_records=batch_records, + header=header, + weight_column=community_weight_name + if entities and include_community_weight + else None, + rank_column=community_rank_name if include_community_rank else None, + ) + if len(record_df) == 0: + return + current_context_text = record_df.to_csv(index=False, sep=column_delimiter) + all_context_text.append(current_context_text) + all_context_records.append(record_df) + + # initialize the first batch + _init_batch() + + for report in selected_reports: + new_context_text, new_context = _report_context_text(report, attributes, is_optimized_search) + new_tokens = num_tokens(new_context_text, token_encoder) + + if batch_tokens + new_tokens > max_tokens: + # add the current batch to the context data and start a new batch if we are in multi-batch mode + _cut_batch() + if single_batch: + break + _init_batch() + + # add 
current report to the current batch + batch_text += new_context_text + batch_tokens += new_tokens + batch_records.append(new_context) + + # add the last batch if it has not been added + if batch_text not in all_context_text: + _cut_batch() + + if len(all_context_records) == 0: + log.warning( + "Warning: No community records added when building community context." + ) + return ([], {}) + + return all_context_text, { + context_name.lower(): pd.concat(all_context_records, ignore_index=True) + } + + +def _compute_community_weights( + community_reports: list[CommunityReport], + entities: list[Entity] | None, + weight_attribute: str = "occurrence", + normalize: bool = True, +) -> list[CommunityReport]: + """Calculate a community's weight as count of text units associated with entities within the community.""" + if not entities: + return community_reports + + community_text_units = {} + for entity in entities: + if entity.community_ids: + for community_id in entity.community_ids: + if community_id not in community_text_units: + community_text_units[community_id] = [] + community_text_units[community_id].extend(entity.text_unit_ids) + for report in community_reports: + if not report.attributes: + report.attributes = {} + report.attributes[weight_attribute] = len( + set(community_text_units.get(report.community_id, [])) + ) + if normalize: + # normalize by max weight + all_weights = [ + report.attributes[weight_attribute] + for report in community_reports + if report.attributes + ] + max_weight = max(all_weights) + for report in community_reports: + if report.attributes: + report.attributes[weight_attribute] = ( + report.attributes[weight_attribute] / max_weight + ) + return community_reports + + +def _rank_report_context( + report_df: pd.DataFrame, + weight_column: str | None = "occurrence weight", + rank_column: str | None = "rank", +) -> pd.DataFrame: + """Sort report context by community weight and rank if exist.""" + rank_attributes: list[str] = [] + if weight_column: + rank_attributes.append(weight_column) + report_df[weight_column] = report_df[weight_column].astype(float) + if rank_column: + rank_attributes.append(rank_column) + report_df[rank_column] = report_df[rank_column].astype(float) + if len(rank_attributes) > 0: + report_df.sort_values(by=rank_attributes, ascending=False, inplace=True) + return report_df + + +def _convert_report_context_to_df( + context_records: list[list[str]], + header: list[str], + weight_column: str | None = None, + rank_column: str | None = None, +) -> pd.DataFrame: + """Convert report context records to pandas dataframe and sort by weight and rank if exist.""" + if len(context_records) == 0: + return pd.DataFrame() + + record_df = pd.DataFrame( + context_records, + columns=cast(Any, header), + ) + return _rank_report_context( + report_df=record_df, + weight_column=weight_column, + rank_column=rank_column, + ) diff --git a/func-app/graphrag/query/context_builder/conversation_history.py b/func-app/graphrag/query/context_builder/conversation_history.py new file mode 100644 index 0000000000..33f516dbd4 --- /dev/null +++ b/func-app/graphrag/query/context_builder/conversation_history.py @@ -0,0 +1,212 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Classes for storing and managing conversation history."""
+
+from dataclasses import dataclass
+from enum import Enum
+
+import pandas as pd
+import tiktoken
+
+from graphrag.query.llm.text_utils import num_tokens
+
+
+class ConversationRole(str, Enum):
+    """Enum for conversation roles."""
+
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+
+    @staticmethod
+    def from_string(value: str) -> "ConversationRole":
+        """Convert string to ConversationRole."""
+        if value == "system":
+            return ConversationRole.SYSTEM
+        if value == "user":
+            return ConversationRole.USER
+        if value == "assistant":
+            return ConversationRole.ASSISTANT
+
+        msg = f"Invalid Role: {value}"
+        raise ValueError(msg)
+
+    def __str__(self) -> str:
+        """Return string representation of the enum value."""
+        return self.value
+
+
+@dataclass
+class ConversationTurn:
+    """Data class for storing a single conversation turn."""
+
+    role: ConversationRole
+    content: str
+
+    def __str__(self) -> str:
+        """Return string representation of the conversation turn."""
+        return f"{self.role}: {self.content}"
+
+
+@dataclass
+class QATurn:
+    """
+    Data class for storing a QA turn.
+
+    A QA turn contains a user question and one or more assistant answers.
+    """
+
+    user_query: ConversationTurn
+    assistant_answers: list[ConversationTurn] | None = None
+
+    def get_answer_text(self) -> str | None:
+        """Get the text of the assistant answers."""
+        return (
+            "\n".join([answer.content for answer in self.assistant_answers])
+            if self.assistant_answers
+            else None
+        )
+
+    def __str__(self) -> str:
+        """Return string representation of the QA turn."""
+        answers = self.get_answer_text()
+        return (
+            f"Question: {self.user_query.content}\nAnswer: {answers}"
+            if answers
+            else f"Question: {self.user_query.content}"
+        )
+
+
+class ConversationHistory:
+    """Class for storing a conversation history."""
+
+    turns: list[ConversationTurn]
+
+    def __init__(self):
+        self.turns = []
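+
+    # Illustrative usage, assuming hypothetical turns:
+    #   history = ConversationHistory.from_list([
+    #       {"role": "user", "content": "What is GraphRAG?"},
+    #       {"role": "assistant", "content": "A graph-based RAG pipeline."},
+    #   ])
+    #   context_text, context_data = history.build_context(max_qa_turns=1)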
+
+    @classmethod
+    def from_list(
+        cls, conversation_turns: list[dict[str, str]]
+    ) -> "ConversationHistory":
+        """
+        Create a conversation history from a list of conversation turns.
+
+        Each turn is a dictionary in the form of {"role": "<conversation_role>", "content": "<turn content>"}
+        """
+        history = cls()
+        for turn in conversation_turns:
+            history.turns.append(
+                ConversationTurn(
+                    role=ConversationRole.from_string(
+                        turn.get("role", ConversationRole.USER)
+                    ),
+                    content=turn.get("content", ""),
+                )
+            )
+        return history
+
+    def add_turn(self, role: ConversationRole, content: str):
+        """Add a new turn to the conversation history."""
+        self.turns.append(ConversationTurn(role=role, content=content))
+
+    def to_qa_turns(self) -> list[QATurn]:
+        """Convert conversation history to a list of QA turns."""
+        qa_turns = list[QATurn]()
+        current_qa_turn = None
+        for turn in self.turns:
+            if turn.role == ConversationRole.USER:
+                if current_qa_turn:
+                    qa_turns.append(current_qa_turn)
+                current_qa_turn = QATurn(user_query=turn, assistant_answers=[])
+            else:
+                if current_qa_turn:
+                    current_qa_turn.assistant_answers.append(turn)  # type: ignore
+        if current_qa_turn:
+            qa_turns.append(current_qa_turn)
+        return qa_turns
+
+    def get_user_turns(self, max_user_turns: int | None = 1) -> list[str]:
+        """Get the last user turns in the conversation history."""
+        user_turns = []
+        for turn in self.turns[::-1]:
+            if turn.role == ConversationRole.USER:
+                user_turns.append(turn.content)
+                if max_user_turns and len(user_turns) >= max_user_turns:
+                    break
+        return user_turns
+
+    def build_context(
+        self,
+        token_encoder: tiktoken.Encoding | None = None,
+        include_user_turns_only: bool = True,
+        max_qa_turns: int | None = 5,
+        max_tokens: int = 8000,
+        recency_bias: bool = True,
+        column_delimiter: str = "|",
+        context_name: str = "Conversation History",
+    ) -> tuple[str, dict[str, pd.DataFrame]]:
+        """
+        Prepare conversation history as context data for system prompt.
+
+        Parameters
+        ----------
+        include_user_turns_only: If True, only user queries (not assistant responses) will be included in the context, default is True.
+        max_qa_turns: Maximum number of QA turns to include in the context, default is 5.
+        recency_bias: If True, reverse the order of the conversation history so the most recent turns are prioritized.
+        column_delimiter: Delimiter to use for separating columns in the context data, default is "|".
+        context_name: Name of the context, default is "Conversation History".
+        """
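+        # Note: with recency_bias the turns are consumed newest-first, and the
+        # loop below stops adding turns once the rendered table would exceed
+        # max_tokens, so the oldest turns are dropped first.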
+        qa_turns = self.to_qa_turns()
+        if include_user_turns_only:
+            qa_turns = [
+                QATurn(user_query=qa_turn.user_query, assistant_answers=None)
+                for qa_turn in qa_turns
+            ]
+        if recency_bias:
+            qa_turns = qa_turns[::-1]
+        if max_qa_turns and len(qa_turns) > max_qa_turns:
+            qa_turns = qa_turns[:max_qa_turns]
+
+        # build context for qa turns
+        # add context header
+        if not qa_turns:
+            return ("", {context_name: pd.DataFrame()})
+
+        # add table header
+        header = f"-----{context_name}-----" + "\n"
+
+        turn_list = []
+        current_context_df = pd.DataFrame()
+        for turn in qa_turns:
+            turn_list.append({
+                "turn": ConversationRole.USER.__str__(),
+                "content": turn.user_query.content,
+            })
+            if turn.assistant_answers:
+                turn_list.append({
+                    "turn": ConversationRole.ASSISTANT.__str__(),
+                    "content": turn.get_answer_text(),
+                })
+
+            context_df = pd.DataFrame(turn_list)
+            context_text = header + context_df.to_csv(sep=column_delimiter, index=False)
+            if num_tokens(context_text, token_encoder) > max_tokens:
+                break
+
+            current_context_df = context_df
+        context_text = header + current_context_df.to_csv(
+            sep=column_delimiter, index=False
+        )
+        return (context_text, {context_name.lower(): current_context_df})
diff --git a/func-app/graphrag/query/context_builder/entity_extraction.py b/func-app/graphrag/query/context_builder/entity_extraction.py
new file mode 100644
index 0000000000..037da80932
--- /dev/null
+++ b/func-app/graphrag/query/context_builder/entity_extraction.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Orchestration Context Builders."""
+
+import ast
+from enum import Enum
+
+from graphrag.model import Entity, Relationship
+from graphrag.query.input.retrieval.entities import (
+    get_entity_by_key,
+    get_entity_by_name,
+)
+from graphrag.query.llm.base import BaseTextEmbedding
+from graphrag.vector_stores import BaseVectorStore
+
+
+class EntityVectorStoreKey(str, Enum):
+    """Keys used as ids in the entity embedding vectorstores."""
+
+    ID = "id"
+    TITLE = "title"
+
+    @staticmethod
+    def from_string(value: str) -> "EntityVectorStoreKey":
+        """Convert string to EntityVectorStoreKey."""
+        if value == "id":
+            return EntityVectorStoreKey.ID
+        if value == "title":
+            return EntityVectorStoreKey.TITLE
+
+        msg = f"Invalid EntityVectorStoreKey: {value}"
+        raise ValueError(msg)
+
+def map_query_to_entities_in_place(
+    query: str,
+    text_embedding_vectorstore: BaseVectorStore,
+    text_embedder: BaseTextEmbedding,
+    k: int = 10,
+    oversample_scaler: int = 2,
+) -> list[Entity]:
+    """Extract entities that match a given query using semantic similarity of text embeddings of query and entity descriptions."""
+    # get entities with highest semantic similarity to query
+    # oversample to account for excluded entities
+    search_results = text_embedding_vectorstore.get_extracted_entities(
+        text=query,
+        text_embedder=lambda t: text_embedder.embed(t),
+        k=k * oversample_scaler,
+    )
+    for result in search_results:
+        result.community_ids = ast.literal_eval(result.community_ids)
+    return search_results

+def map_query_to_entities(
+    query: str,
+    text_embedding_vectorstore: BaseVectorStore,
+    text_embedder: BaseTextEmbedding,
+    all_entities: list[Entity],
+    embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
+    include_entity_names: list[str] | None = None,
+    exclude_entity_names: list[str] | None = None,
+    k: int = 10,
+    oversample_scaler: int = 2,
+) -> list[Entity]:
+    """Extract entities that match a given query using semantic 
similarity of text embeddings of query and entity descriptions.""" + if all_entities == []: + return map_query_to_entities_in_place( + query, + text_embedding_vectorstore, + text_embedder, + k, + oversample_scaler, + ) + + if include_entity_names is None: + include_entity_names = [] + if exclude_entity_names is None: + exclude_entity_names = [] + matched_entities = [] + if query != "": + # get entities with highest semantic similarity to query + # oversample to account for excluded entities + search_results = text_embedding_vectorstore.similarity_search_by_text( + text=query, + text_embedder=lambda t: text_embedder.embed(t), + k=k * oversample_scaler, + ) + for result in search_results: + matched = get_entity_by_key( + entities=all_entities, + key=embedding_vectorstore_key, + value=result.document.id, + ) + if matched: + matched_entities.append(matched) + else: + all_entities.sort(key=lambda x: x.rank if x.rank else 0, reverse=True) + matched_entities = all_entities[:k] + + # filter out excluded entities + if exclude_entity_names: + matched_entities = [ + entity + for entity in matched_entities + if entity.title not in exclude_entity_names + ] + + # add entities in the include_entity list + included_entities = [] + for entity_name in include_entity_names: + included_entities.extend(get_entity_by_name(all_entities, entity_name)) + return included_entities + matched_entities + + +def find_nearest_neighbors_by_graph_embeddings( + entity_id: str, + graph_embedding_vectorstore: BaseVectorStore, + all_entities: list[Entity], + exclude_entity_names: list[str] | None = None, + embedding_vectorstore_key: str = EntityVectorStoreKey.ID, + k: int = 10, + oversample_scaler: int = 2, +) -> list[Entity]: + """Retrieve related entities by graph embeddings.""" + if exclude_entity_names is None: + exclude_entity_names = [] + # find nearest neighbors of this entity using graph embedding + query_entity = get_entity_by_key( + entities=all_entities, key=embedding_vectorstore_key, value=entity_id + ) + query_embedding = query_entity.graph_embedding if query_entity else None + + # oversample to account for excluded entities + if query_embedding: + matched_entities = [] + search_results = graph_embedding_vectorstore.similarity_search_by_vector( + query_embedding=query_embedding, k=k * oversample_scaler + ) + for result in search_results: + matched = get_entity_by_key( + entities=all_entities, + key=embedding_vectorstore_key, + value=result.document.id, + ) + if matched: + matched_entities.append(matched) + + # filter out excluded entities + if exclude_entity_names: + matched_entities = [ + entity + for entity in matched_entities + if entity.title not in exclude_entity_names + ] + matched_entities.sort(key=lambda x: x.rank, reverse=True) + return matched_entities[:k] + + return [] + + +def find_nearest_neighbors_by_entity_rank( + entity_name: str, + all_entities: list[Entity], + all_relationships: list[Relationship], + exclude_entity_names: list[str] | None = None, + k: int | None = 10, +) -> list[Entity]: + """Retrieve entities that have direct connections with the target entity, sorted by entity rank.""" + if exclude_entity_names is None: + exclude_entity_names = [] + entity_relationships = [ + rel + for rel in all_relationships + if rel.source == entity_name or rel.target == entity_name + ] + source_entity_names = {rel.source for rel in entity_relationships} + target_entity_names = {rel.target for rel in entity_relationships} + related_entity_names = (source_entity_names.union(target_entity_names)).difference( + 
set(exclude_entity_names) + ) + top_relations = [ + entity for entity in all_entities if entity.title in related_entity_names + ] + top_relations.sort(key=lambda x: x.rank if x.rank else 0, reverse=True) + if k: + return top_relations[:k] + return top_relations diff --git a/func-app/graphrag/query/context_builder/local_context.py b/func-app/graphrag/query/context_builder/local_context.py new file mode 100644 index 0000000000..b48bf0bf96 --- /dev/null +++ b/func-app/graphrag/query/context_builder/local_context.py @@ -0,0 +1,360 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Local Context Builder.""" + +from collections import defaultdict +from typing import Any, cast + +import pandas as pd +from common.graph_db_client import GraphDBClient +import tiktoken + +from graphrag.model import Covariate, Entity, Relationship +from graphrag.query.input.retrieval.covariates import ( + get_candidate_covariates, + to_covariate_dataframe, +) +from graphrag.query.input.retrieval.entities import to_entity_dataframe +from graphrag.query.input.retrieval.relationships import ( + get_candidate_relationships, + get_entities_from_relationships, + get_in_network_relationships, + get_out_network_relationships, + to_relationship_dataframe, +) +from graphrag.query.llm.text_utils import num_tokens + + +def build_entity_context( + selected_entities: list[Entity], + token_encoder: tiktoken.Encoding | None = None, + max_tokens: int = 8000, + include_entity_rank: bool = True, + rank_description: str = "number of relationships", + column_delimiter: str = "|", + context_name="Entities", + is_optimized_search: bool = False +) -> tuple[str, pd.DataFrame]: + """Prepare entity data table as context data for system prompt.""" + if len(selected_entities) == 0: + return "", pd.DataFrame() + + # add headers + current_context_text = f"-----{context_name}-----" + "\n" + header = ["id", "entity", "description"] + if include_entity_rank: + header.append(rank_description) + attribute_cols = ( + list(selected_entities[0].attributes.keys()) + if selected_entities[0].attributes + else [] + ) + header.extend(attribute_cols) + current_context_text += column_delimiter.join(header) + "\n" + current_tokens = num_tokens(current_context_text, token_encoder) + + all_context_records = [header] + for entity in selected_entities: + new_context = [ + entity.short_id if entity.short_id else "", + entity.title, + entity.description if entity.description else "", + ] + if include_entity_rank: + new_context.append(str(entity.rank)) + for field in attribute_cols: + field_value = ( + str(entity.attributes.get(field)) + if entity.attributes and entity.attributes.get(field) + else "" + ) + new_context.append(field_value) + new_tokens: int = 0 + if not is_optimized_search: + new_context_text = column_delimiter.join(new_context) + "\n" + new_tokens = num_tokens(new_context_text, token_encoder) + if current_tokens + new_tokens > max_tokens: + break + current_context_text += new_context_text + all_context_records.append(new_context) + current_tokens += new_tokens + + if len(all_context_records) > 1: + record_df = pd.DataFrame( + all_context_records[1:], columns=cast(Any, all_context_records[0]) + ) + else: + record_df = pd.DataFrame() + + return current_context_text, record_df + + +def build_covariates_context( + selected_entities: list[Entity], + covariates: list[Covariate], + token_encoder: tiktoken.Encoding | None = None, + max_tokens: int = 8000, + column_delimiter: str = "|", + context_name: str = "Covariates", + 
is_optimized_search: bool = False +) -> tuple[str, pd.DataFrame]: + """Prepare covariate data tables as context data for system prompt.""" + # create an empty list of covariates + if len(selected_entities) == 0 or len(covariates) == 0: + return "", pd.DataFrame() + + selected_covariates = list[Covariate]() + record_df = pd.DataFrame() + + # add context header + current_context_text = f"-----{context_name}-----" + "\n" + + # add header + header = ["id", "entity"] + attributes = covariates[0].attributes or {} if len(covariates) > 0 else {} + attribute_cols = list(attributes.keys()) if len(covariates) > 0 else [] + header.extend(attribute_cols) + current_context_text += column_delimiter.join(header) + "\n" + current_tokens = num_tokens(current_context_text, token_encoder) + + all_context_records = [header] + for entity in selected_entities: + selected_covariates.extend([ + cov for cov in covariates if cov.subject_id == entity.title + ]) + + for covariate in selected_covariates: + new_context = [ + covariate.short_id if covariate.short_id else "", + covariate.subject_id, + ] + for field in attribute_cols: + field_value = ( + str(covariate.attributes.get(field)) + if covariate.attributes and covariate.attributes.get(field) + else "" + ) + new_context.append(field_value) + + new_context_text = column_delimiter.join(new_context) + "\n" + new_tokens = num_tokens(new_context_text, token_encoder) + if current_tokens + new_tokens > max_tokens: + break + current_context_text += new_context_text + all_context_records.append(new_context) + current_tokens += new_tokens + + if len(all_context_records) > 1: + record_df = pd.DataFrame( + all_context_records[1:], columns=cast(Any, all_context_records[0]) + ) + else: + record_df = pd.DataFrame() + + return current_context_text, record_df + + +def build_relationship_context( + selected_entities: list[Entity], + relationships: list[Relationship], + token_encoder: tiktoken.Encoding | None = None, + include_relationship_weight: bool = False, + max_tokens: int = 8000, + top_k_relationships: int = 10, + relationship_ranking_attribute: str = "rank", + column_delimiter: str = "|", + context_name: str = "Relationships", + is_optimized_search: bool = False, + graphdb_client: GraphDBClient|None=None, +) -> tuple[str, pd.DataFrame]: + """Prepare relationship data tables as context data for system prompt.""" + selected_relationships = _filter_relationships( + selected_entities=selected_entities, + relationships=relationships, + top_k_relationships=top_k_relationships, + relationship_ranking_attribute=relationship_ranking_attribute, + graphdb_client=graphdb_client, + ) + + if len(selected_entities) == 0 or len(selected_relationships) == 0: + return "", pd.DataFrame() + + # add headers + current_context_text = f"-----{context_name}-----" + "\n" + header = ["id", "source", "target", "description"] + if include_relationship_weight: + header.append("weight") + attribute_cols = ( + list(selected_relationships[0].attributes.keys()) + if selected_relationships[0].attributes + else [] + ) + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + + current_context_text += column_delimiter.join(header) + "\n" + current_tokens = num_tokens(current_context_text, token_encoder) + + all_context_records = [header] + for rel in selected_relationships: + new_context = [ + rel.short_id if rel.short_id else "", + rel.source, + rel.target, + rel.description if rel.description else "", + ] + if include_relationship_weight: + 
new_context.append(str(rel.weight if rel.weight else ""))
+        for field in attribute_cols:
+            field_value = (
+                str(rel.attributes.get(field))
+                if rel.attributes and rel.attributes.get(field)
+                else ""
+            )
+            new_context.append(field_value)
+        new_context_text = ""
+        new_tokens = 0
+        if not is_optimized_search:
+            new_context_text = column_delimiter.join(new_context) + "\n"
+            new_tokens = num_tokens(new_context_text, token_encoder)
+            if current_tokens + new_tokens > max_tokens:  # note: truncating here avoids the side effects of emitting a huge number of relationships
+                break
+            current_context_text += new_context_text
+        all_context_records.append(new_context)
+        current_tokens += new_tokens
+
+    if len(all_context_records) > 1:
+        record_df = pd.DataFrame(
+            all_context_records[1:], columns=cast(Any, all_context_records[0])
+        )
+    else:
+        record_df = pd.DataFrame()
+
+    return current_context_text, record_df
+
+
+def _filter_relationships(
+    selected_entities: list[Entity],
+    relationships: list[Relationship],
+    top_k_relationships: int = 10,
+    relationship_ranking_attribute: str = "rank",
+    graphdb_client: GraphDBClient|None=None,
+) -> list[Relationship]:
+    """Filter and sort relationships based on a set of selected entities and a ranking attribute."""
+    # First priority: in-network relationships (i.e. relationships between selected entities)
+    in_network_relationships = get_in_network_relationships(
+        selected_entities=selected_entities,
+        relationships=relationships,
+        ranking_attribute=relationship_ranking_attribute,
+        graphdb_client=graphdb_client,
+    )
+
+    # Second priority - out-of-network relationships
+    # (i.e. relationships between selected entities and other entities that are not within the selected entities)
+    out_network_relationships = get_out_network_relationships(
+        selected_entities=selected_entities,
+        relationships=relationships,
+        ranking_attribute=relationship_ranking_attribute,
+        graphdb_client=graphdb_client,
+    )
+    if len(out_network_relationships) <= 1:
+        return in_network_relationships + out_network_relationships
+
+    # within out-of-network relationships, prioritize mutual relationships
+    # (i.e. 
+    selected_entity_names = [entity.title for entity in selected_entities]
+    out_network_source_names = [
+        relationship.source
+        for relationship in out_network_relationships
+        if relationship.source not in selected_entity_names
+    ]
+    out_network_target_names = [
+        relationship.target
+        for relationship in out_network_relationships
+        if relationship.target not in selected_entity_names
+    ]
+    out_network_entity_names = list(
+        set(out_network_source_names + out_network_target_names)
+    )
+    out_network_entity_links = defaultdict(int)
+    for entity_name in out_network_entity_names:
+        targets = [
+            relationship.target
+            for relationship in out_network_relationships
+            if relationship.source == entity_name
+        ]
+        sources = [
+            relationship.source
+            for relationship in out_network_relationships
+            if relationship.target == entity_name
+        ]
+        out_network_entity_links[entity_name] = len(set(targets + sources))
+
+    # sort out-network relationships by number of links and the ranking attribute
+    for rel in out_network_relationships:
+        if rel.attributes is None:
+            rel.attributes = {}
+        rel.attributes["links"] = (
+            out_network_entity_links[rel.source]
+            if rel.source in out_network_entity_links
+            else out_network_entity_links[rel.target]
+        )
+
+    # sort by attributes[links] first, then by ranking_attribute
+    if relationship_ranking_attribute == "weight":
+        out_network_relationships.sort(
+            key=lambda x: (x.attributes["links"], x.weight),  # type: ignore
+            reverse=True,  # type: ignore
+        )
+    else:
+        out_network_relationships.sort(
+            key=lambda x: (
+                x.attributes["links"],  # type: ignore
+                x.attributes[relationship_ranking_attribute],  # type: ignore
+            ),  # type: ignore
+            reverse=True,
+        )
+
+    relationship_budget = top_k_relationships * len(selected_entities)
+    return in_network_relationships + out_network_relationships[:relationship_budget]
+
+
+def get_candidate_context(
+    selected_entities: list[Entity],
+    entities: list[Entity],
+    relationships: list[Relationship],
+    covariates: dict[str, list[Covariate]],
+    include_entity_rank: bool = True,
+    entity_rank_description: str = "number of relationships",
+    include_relationship_weight: bool = False,
+) -> dict[str, pd.DataFrame]:
+    """Prepare entity, relationship, and covariate data tables as context data for system prompt."""
+    candidate_context = {}
+    candidate_relationships = get_candidate_relationships(
+        selected_entities=selected_entities,
+        relationships=relationships,
+    )
+    candidate_context["relationships"] = to_relationship_dataframe(
+        relationships=candidate_relationships,
+        include_relationship_weight=include_relationship_weight,
+    )
+    candidate_entities = get_entities_from_relationships(
+        relationships=candidate_relationships, entities=entities
+    )
+    candidate_context["entities"] = to_entity_dataframe(
+        entities=candidate_entities,
+        include_entity_rank=include_entity_rank,
+        rank_description=entity_rank_description,
+    )
+
+    for covariate in covariates:
+        candidate_covariates = get_candidate_covariates(
+            selected_entities=selected_entities,
+            covariates=covariates[covariate],
+        )
+        candidate_context[covariate.lower()] = to_covariate_dataframe(
+            candidate_covariates
+        )
+
+    return candidate_context
diff --git a/func-app/graphrag/query/context_builder/source_context.py b/func-app/graphrag/query/context_builder/source_context.py
new file mode 100644
index 0000000000..99b7791ca6
--- /dev/null
+++ b/func-app/graphrag/query/context_builder/source_context.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Context Build utility methods."""
+
+import random
+from typing import Any, cast
+
+import pandas as pd
+import tiktoken
+
+from graphrag.model import Entity, Relationship, TextUnit
+from graphrag.query.llm.text_utils import num_tokens
+
+"""
+Contains utility functions to build text-unit context for the search's system prompt
+"""
+
+
+def build_text_unit_context(
+    text_units: list[TextUnit],
+    token_encoder: tiktoken.Encoding | None = None,
+    column_delimiter: str = "|",
+    shuffle_data: bool = True,
+    max_tokens: int = 8000,
+    context_name: str = "Sources",
+    random_state: int = 86,
+) -> tuple[str, dict[str, pd.DataFrame]]:
+    """Prepare text-unit data table as context data for system prompt."""
+    if text_units is None or len(text_units) == 0:
+        return ("", {})
+
+    if shuffle_data:
+        random.seed(random_state)
+        random.shuffle(text_units)
+
+    # add context header
+    current_context_text = f"-----{context_name}-----" + "\n"
+
+    # add header
+    header = ["id", "text"]
+    attribute_cols = (
+        list(text_units[0].attributes.keys()) if text_units[0].attributes else []
+    )
+    attribute_cols = [col for col in attribute_cols if col not in header]
+    header.extend(attribute_cols)
+
+    current_context_text += column_delimiter.join(header) + "\n"
+    current_tokens = num_tokens(current_context_text, token_encoder)
+    all_context_records = [header]
+
+    for unit in text_units:
+        new_context = [
+            unit.short_id,
+            unit.text,
+            *[
+                str(unit.attributes.get(field, "")) if unit.attributes else ""
+                for field in attribute_cols
+            ],
+        ]
+        new_context_text = column_delimiter.join(new_context) + "\n"
+        new_tokens = num_tokens(new_context_text, token_encoder)
+
+        if current_tokens + new_tokens > max_tokens:
+            break
+
+        current_context_text += new_context_text
+        all_context_records.append(new_context)
+        current_tokens += new_tokens
+
+    if len(all_context_records) > 1:
+        record_df = pd.DataFrame(
+            all_context_records[1:], columns=cast(Any, all_context_records[0])
+        )
+    else:
+        record_df = pd.DataFrame()
+    return current_context_text, {context_name.lower(): record_df}
+
+
+def count_relationships(
+    text_unit: TextUnit, entity: Entity, relationships: dict[str, Relationship]
+) -> int:
+    """Count the number of relationships of the selected entity that are associated with the text unit."""
+    matching_relationships = list[Relationship]()
+    if text_unit.relationship_ids is None:
+        entity_relationships = [
+            rel
+            for rel in relationships.values()
+            if rel.source == entity.title or rel.target == entity.title
+        ]
+        entity_relationships = [
+            rel for rel in entity_relationships if rel.text_unit_ids
+        ]
+        matching_relationships = [
+            rel
+            for rel in entity_relationships
+            if text_unit.id in rel.text_unit_ids  # type: ignore
+        ]  # type: ignore
+    else:
+        text_unit_relationships = [
+            relationships[rel_id]
+            for rel_id in text_unit.relationship_ids
+            if rel_id in relationships
+        ]
+        matching_relationships = [
+            rel
+            for rel in text_unit_relationships
+            if rel.source == entity.title or rel.target == entity.title
+        ]
+    return len(matching_relationships)
diff --git a/func-app/graphrag/query/factories.py b/func-app/graphrag/query/factories.py
new file mode 100644
index 0000000000..28caf61bb0
--- /dev/null
+++ b/func-app/graphrag/query/factories.py
@@ -0,0 +1,211 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Query Factory methods to support CLI.""" + +from graphrag.config.models.graphdb_config import GraphDBConfig +import tiktoken +from azure.identity import ManagedIdentityCredential, get_bearer_token_provider + +from graphrag.config import ( + GraphRagConfig, + LLMType, +) +from graphrag.model import ( + CommunityReport, + Covariate, + Entity, + Relationship, + TextUnit, +) +from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey +from graphrag.query.llm.oai.chat_openai import ChatOpenAI +from graphrag.query.llm.oai.embedding import OpenAIEmbedding +from graphrag.query.llm.oai.typing import OpenaiApiType +from graphrag.query.structured_search.global_search.community_context import ( + GlobalCommunityContext, +) +from graphrag.query.structured_search.global_search.search import GlobalSearch +from graphrag.query.structured_search.local_search.mixed_context import ( + LocalSearchMixedContext, +) +from graphrag.query.structured_search.local_search.search import LocalSearch +from graphrag.vector_stores import BaseVectorStore + + +def get_llm(config: GraphRagConfig) -> ChatOpenAI: + """Get the LLM client.""" + is_azure_client = ( + config.llm.type == LLMType.AzureOpenAIChat + or config.llm.type == LLMType.AzureOpenAI + ) + debug_llm_key = config.llm.api_key or "" + llm_debug_info = { + **config.llm.model_dump(), + "api_key": f"REDACTED,len={len(debug_llm_key)}", + } + if config.llm.cognitive_services_endpoint is None: + cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default" + else: + cognitive_services_endpoint = config.llm.cognitive_services_endpoint + print(f"creating llm client with {llm_debug_info}") # noqa T201 + return ChatOpenAI( + api_key=config.llm.api_key, + azure_ad_token_provider=( + get_bearer_token_provider( + ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint + ) + if is_azure_client and not config.llm.api_key + else None + ), + api_base=config.llm.api_base, + organization=config.llm.organization, + model=config.llm.model, + api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI, + deployment_name=config.llm.deployment_name, + api_version=config.llm.api_version, + max_retries=config.llm.max_retries, + ) + + +def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding: + """Get the LLM client for embeddings.""" + is_azure_client = config.embeddings.llm.type == LLMType.AzureOpenAIEmbedding + debug_embedding_api_key = config.embeddings.llm.api_key or "" + llm_debug_info = { + **config.embeddings.llm.model_dump(), + "api_key": f"REDACTED,len={len(debug_embedding_api_key)}", + } + if config.embeddings.llm.cognitive_services_endpoint is None: + cognitive_services_endpoint = "https://cognitiveservices.azure.com/.default" + else: + cognitive_services_endpoint = config.embeddings.llm.cognitive_services_endpoint + print(f"creating embedding llm client with {llm_debug_info}") # noqa T201 + return OpenAIEmbedding( + api_key=config.embeddings.llm.api_key, + azure_ad_token_provider=( + get_bearer_token_provider( + ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint + ) + if is_azure_client and not config.embeddings.llm.api_key + else None + ), + api_base=config.embeddings.llm.api_base, + organization=config.llm.organization, + api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI, + model=config.embeddings.llm.model, + 
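+        # on Azure the deployment name (not the model id above) selects the served model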
+        deployment_name=config.embeddings.llm.deployment_name,
+        api_version=config.embeddings.llm.api_version,
+        max_retries=config.embeddings.llm.max_retries,
+    )
+
+
+def get_local_search_engine(
+    config: GraphRagConfig,
+    reports: list[CommunityReport],
+    text_units: list[TextUnit],
+    entities: list[Entity],
+    relationships: list[Relationship],
+    covariates: dict[str, list[Covariate]],
+    response_type: str,
+    description_embedding_store: BaseVectorStore,
+    context_id: str,
+    is_optimized_search: bool = False,
+    use_kusto_community_reports: bool = False,
+    graphdb_config: GraphDBConfig|None = None,
+) -> LocalSearch:
+    """Create a local search engine based on data + configuration."""
+    llm = get_llm(config)
+    text_embedder = get_text_embedder(config)
+    token_encoder = tiktoken.get_encoding(config.encoding_model)
+
+    ls_config = config.local_search
+
+    return LocalSearch(
+        llm=llm,
+        context_builder=LocalSearchMixedContext(
+            community_reports=reports,
+            text_units=text_units,
+            entities=entities,
+            relationships=relationships,
+            covariates=covariates,
+            entity_text_embeddings=description_embedding_store,
+            embedding_vectorstore_key=EntityVectorStoreKey.ID,  # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE
+            text_embedder=text_embedder,
+            token_encoder=token_encoder,
+            is_optimized_search=is_optimized_search,
+            use_kusto_community_reports=use_kusto_community_reports,
+            graphdb_config=graphdb_config,
+            context_id=context_id,
+        ),
+        token_encoder=token_encoder,
+        llm_params={
+            "max_tokens": ls_config.llm_max_tokens,  # change this based on the token limit you have on your model (if you are using a model with an 8k limit, a good setting could be 1000-1500)
+            "temperature": ls_config.temperature,
+            "top_p": ls_config.top_p,
+            "n": ls_config.n,
+        },
+        context_builder_params={
+            "text_unit_prop": ls_config.text_unit_prop,
+            "community_prop": ls_config.community_prop,
+            "conversation_history_max_turns": ls_config.conversation_history_max_turns,
+            "conversation_history_user_turns_only": True,
+            "top_k_mapped_entities": ls_config.top_k_entities,
+            "top_k_relationships": ls_config.top_k_relationships,
+            "include_entity_rank": True,
+            "include_relationship_weight": True,
+            "include_community_rank": False,
+            "return_candidate_context": False,
+            "embedding_vectorstore_key": EntityVectorStoreKey.ID,  # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids
+            "max_tokens": ls_config.max_tokens,  # change this based on the token limit you have on your model (if you are using a model with an 8k limit, a good setting could be 5000)
+        },
+        response_type=response_type,
+    )
+
+
+def get_global_search_engine(
+    config: GraphRagConfig,
+    reports: list[CommunityReport],
+    entities: list[Entity],
+    response_type: str,
+):
+    """Create a global search engine based on data + configuration."""
+    token_encoder = tiktoken.get_encoding(config.encoding_model)
+    gs_config = config.global_search
+
+    return GlobalSearch(
+        llm=get_llm(config),
+        context_builder=GlobalCommunityContext(
+            community_reports=reports, entities=entities, token_encoder=token_encoder
+        ),
+        token_encoder=token_encoder,
+        max_data_tokens=gs_config.data_max_tokens,
+        map_llm_params={
+            "max_tokens": gs_config.map_max_tokens,
+            "temperature": gs_config.temperature,
+            "top_p": gs_config.top_p,
+            "n": gs_config.n,
+        },
+        reduce_llm_params={
+            "max_tokens": gs_config.reduce_max_tokens,
+            "temperature": gs_config.temperature,
+            "top_p": gs_config.top_p,
+            "n": gs_config.n,
+        },
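+        # NOTE: the map and reduce phases share sampling parameters but have
+        # separate token budgets (map_max_tokens vs. reduce_max_tokens)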
+        allow_general_knowledge=False,
+        json_mode=False,
+        context_builder_params={
+            "use_community_summary": False,
+            "shuffle_data": True,
+            "include_community_rank": True,
+            "min_community_rank": 0,
+            "community_rank_name": "rank",
+            "include_community_weight": True,
+            "community_weight_name": "occurrence weight",
+            "normalize_community_weight": True,
+            "max_tokens": gs_config.max_tokens,
+            "context_name": "Reports",
+        },
+        concurrent_coroutines=gs_config.concurrency,
+        response_type=response_type,
+    )
diff --git a/func-app/graphrag/query/indexer_adapters.py b/func-app/graphrag/query/indexer_adapters.py
new file mode 100644
index 0000000000..101fc16f9c
--- /dev/null
+++ b/func-app/graphrag/query/indexer_adapters.py
@@ -0,0 +1,159 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+"""Indexing-Engine to Query Read Adapters.
+
+The parts of these functions that do type adaptation, renaming, collating, etc. should eventually go away.
+Ideally this is just a straight read-through into the object model.
+"""
+
+from typing import cast
+
+import pandas as pd
+
+from graphrag.model import CommunityReport, Covariate, Entity, Relationship, TextUnit
+from graphrag.query.input.loaders.dfs import (
+    read_community_reports,
+    read_covariates,
+    read_entities,
+    read_relationships,
+    read_text_units,
+)
+
+from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
+
+def read_indexer_text_units(final_text_units: pd.DataFrame) -> list[TextUnit]:
+    """Read in the Text Units from the raw indexing outputs."""
+    return read_text_units(
+        df=final_text_units,
+        short_id_col=None,
+        # expects a covariate map of type -> ids
+        covariates_col=None,
+    )
+
+
+def read_indexer_covariates(final_covariates: pd.DataFrame) -> list[Covariate]:
+    """Read in the Claims from the raw indexing outputs."""
+    covariate_df = final_covariates
+    covariate_df["id"] = covariate_df["id"].astype(str)
+    return read_covariates(
+        df=covariate_df,
+        short_id_col="human_readable_id",
+        attributes_cols=[
+            "object_id",
+            "status",
+            "start_date",
+            "end_date",
+            "description",
+        ],
+        text_unit_ids_col=None,
+    )
+
+# GraphDB: read relationships from the graph db.
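+# For example (hypothetical usage, assuming the standard parquet outputs):
+#
+#   final_relationships = pd.read_parquet("output/create_final_relationships.parquet")
+#   relationships = read_indexer_relationships(final_relationships)
+#
+# The "rank" column is kept as an attribute so context builders can sort on it.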
+def read_indexer_relationships(final_relationships: pd.DataFrame) -> list[Relationship]:
+    """Read in the Relationships from the raw indexing outputs."""
+    return read_relationships(
+        df=final_relationships,
+        short_id_col="human_readable_id",
+        description_embedding_col=None,
+        document_ids_col=None,
+        attributes_cols=["rank"],
+    )
+
+def kt_read_indexer_reports(
+    vs,  # a Kusto-backed vector store exposing .client and .database
+    community_level: int,
+) -> bool:
+    """Materialize community reports at or below community_level into an intermediate Kusto table."""
+    vs.client.execute(vs.database, '.drop table interm_rep ifexists')
+
+    cmd = f'''
+    .set interm_rep <| (create_final_community_reports | where level <= {community_level} |
+    join kind=inner (create_final_nodes |
+    where level <= {community_level} | summarize community=max(community) by ['title'] | summarize by community )
+    on community | project-away community1)
+    '''
+
+    res = vs.client.execute(vs.database, cmd)
+    return True  # TODO: error checking should be added later
+
+def read_indexer_reports(
+    final_community_reports: pd.DataFrame,
+    final_nodes: pd.DataFrame,
+    community_level: int,
+) -> list[CommunityReport]:
+    """Read in the Community Reports from the raw indexing outputs."""
+    report_df = final_community_reports
+    entity_df = final_nodes
+    entity_df = _filter_under_community_level(entity_df, community_level)
+    entity_df["community"] = entity_df["community"].fillna(-1)
+    entity_df["community"] = entity_df["community"].astype(int)
+
+    entity_df = entity_df.groupby(["title"]).agg({"community": "max"}).reset_index()
+    entity_df["community"] = entity_df["community"].astype(str)
+    filtered_community_df = entity_df["community"].drop_duplicates()
+
+    report_df = _filter_under_community_level(report_df, community_level)
+    report_df = report_df.merge(filtered_community_df, on="community", how="inner")
+    report_df = report_df.drop_duplicates(subset=["community"])
+
+    return read_community_reports(
+        df=report_df,
+        id_col="community",
+        short_id_col="community",
+        summary_embedding_col=None,
+        content_embedding_col=None,
+    )
+
+
+def read_indexer_entities(
+    final_nodes: pd.DataFrame,
+    final_entities: pd.DataFrame,
+    community_level: int,
+) -> list[Entity]:
+    """Read in the Entities from the raw indexing outputs."""
+    entity_df = final_nodes
+    entity_embedding_df = final_entities
+
+    entity_df = _filter_under_community_level(entity_df, community_level)
+    entity_df = cast(pd.DataFrame, entity_df[["title", "degree", "community"]]).rename(
+        columns={"title": "name", "degree": "rank"}
+    )
+
+    entity_df["community"] = entity_df["community"].fillna(-1)
+    entity_df["community"] = entity_df["community"].astype(int)
+    entity_df["rank"] = entity_df["rank"].astype(int)
+
+    # for duplicate entities, keep the one with the highest community level
+    entity_df = (
+        entity_df.groupby(["name", "rank"]).agg({"community": "max"}).reset_index()
+    )
+    entity_df["community"] = entity_df["community"].apply(lambda x: [str(x)])
+    entity_df = entity_df.merge(
+        entity_embedding_df, on="name", how="inner"
+    ).drop_duplicates(subset=["name"])
+
+    # read entity dataframe to knowledge model objects
+    return read_entities(
+        df=entity_df,
+        id_col="id",
+        title_col="name",
+        type_col="type",
+        short_id_col="human_readable_id",
+        description_col="description",
+        community_col="community",
+        rank_col="rank",
+        name_embedding_col=None,
+        description_embedding_col="description_embedding",
+        graph_embedding_col=None,
+        text_unit_ids_col="text_unit_ids",
+        document_ids_col=None,
+    )
+
+
+def _filter_under_community_level(
+    df: pd.DataFrame, community_level: int
+) -> pd.DataFrame:
+    return cast(
+        pd.DataFrame,
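+        # keep only rows at or below the requested community level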
+        df[df.level <= community_level],
+    )
diff --git a/func-app/graphrag/query/input/__init__.py b/func-app/graphrag/query/input/__init__.py
new file mode 100644
index 0000000000..94ae973477
--- /dev/null
+++ b/func-app/graphrag/query/input/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""GraphRAG Orchestration Inputs."""
diff --git a/func-app/graphrag/query/input/loaders/__init__.py b/func-app/graphrag/query/input/loaders/__init__.py
new file mode 100644
index 0000000000..8f19dac0dd
--- /dev/null
+++ b/func-app/graphrag/query/input/loaders/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""GraphRAG Orchestration Input Loaders."""
diff --git a/func-app/graphrag/query/input/loaders/dfs.py b/func-app/graphrag/query/input/loaders/dfs.py
new file mode 100644
index 0000000000..7312963bb8
--- /dev/null
+++ b/func-app/graphrag/query/input/loaders/dfs.py
@@ -0,0 +1,340 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Load data from dataframes into collections of data objects."""
+
+import pandas as pd
+
+from graphrag.model import (
+    Community,
+    CommunityReport,
+    Covariate,
+    Document,
+    Entity,
+    Relationship,
+    TextUnit,
+)
+from graphrag.query.input.loaders.utils import (
+    to_list,
+    to_optional_dict,
+    to_optional_float,
+    to_optional_int,
+    to_optional_list,
+    to_optional_str,
+    to_str,
+)
+from graphrag.vector_stores import BaseVectorStore, VectorStoreDocument
+
+
+def read_entities(
+    df: pd.DataFrame,
+    id_col: str = "id",
+    short_id_col: str | None = "short_id",
+    title_col: str = "title",
+    type_col: str | None = "type",
+    description_col: str | None = "description",
+    name_embedding_col: str | None = "name_embedding",
+    description_embedding_col: str | None = "description_embedding",
+    graph_embedding_col: str | None = "graph_embedding",
+    community_col: str | None = "community_ids",
+    text_unit_ids_col: str | None = "text_unit_ids",
+    document_ids_col: str | None = "document_ids",
+    rank_col: str | None = "degree",
+    attributes_cols: list[str] | None = None,
+) -> list[Entity]:
+    """Read entities from a dataframe."""
+    entities = []
+    for idx, row in df.iterrows():
+        entity = Entity(
+            id=to_str(row, id_col),
+            short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
+            title=to_str(row, title_col),
+            type=to_optional_str(row, type_col),
+            description=to_optional_str(row, description_col),
+            name_embedding=to_optional_list(row, name_embedding_col, item_type=float),
+            description_embedding=to_optional_list(
+                row, description_embedding_col, item_type=float
+            ),
+            graph_embedding=to_optional_list(row, graph_embedding_col, item_type=float),
+            community_ids=to_optional_list(row, community_col, item_type=str),
+            text_unit_ids=to_optional_list(row, text_unit_ids_col),
+            document_ids=to_optional_list(row, document_ids_col),
+            rank=to_optional_int(row, rank_col),
+            attributes=(
+                {col: row.get(col) for col in attributes_cols}
+                if attributes_cols
+                else None
+            ),
+        )
+        entities.append(entity)
+    return entities
+
+
+def store_entity_semantic_embeddings(
+    entities: list[Entity],
+    vectorstore: BaseVectorStore,
+) -> BaseVectorStore:
+    """Store entity semantic embeddings in a vectorstore."""
+    documents = [
+        VectorStoreDocument(
+            id=entity.id,
+            text=entity.description,
+            vector=entity.description_embedding,
+            attributes=(
+                {"title": entity.title, **entity.attributes}
+                if entity.attributes
+                else {"title":
entity.title} + ), + ) + for entity in entities + ] + vectorstore.load_documents(documents=documents) + return vectorstore + + +def store_entity_behavior_embeddings( + entities: list[Entity], + vectorstore: BaseVectorStore, +) -> BaseVectorStore: + """Store entity behavior embeddings in a vectorstore.""" + documents = [ + VectorStoreDocument( + id=entity.id, + text=entity.description, + vector=entity.graph_embedding, + attributes=( + {"title": entity.title, **entity.attributes} + if entity.attributes + else {"title": entity.title} + ), + ) + for entity in entities + ] + vectorstore.load_documents(documents=documents) + return vectorstore + + +def read_relationships( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str | None = "short_id", + source_col: str = "source", + target_col: str = "target", + description_col: str | None = "description", + description_embedding_col: str | None = "description_embedding", + weight_col: str | None = "weight", + text_unit_ids_col: str | None = "text_unit_ids", + document_ids_col: str | None = "document_ids", + attributes_cols: list[str] | None = None, +) -> list[Relationship]: + """Read relationships from a dataframe.""" + relationships = [] + for idx, row in df.iterrows(): + rel = Relationship( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + source=to_str(row, source_col), + target=to_str(row, target_col), + description=to_optional_str(row, description_col), + description_embedding=to_optional_list( + row, description_embedding_col, item_type=float + ), + weight=to_optional_float(row, weight_col), + text_unit_ids=to_optional_list(row, text_unit_ids_col, item_type=str), + document_ids=to_optional_list(row, document_ids_col, item_type=str), + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + relationships.append(rel) + return relationships + + +def read_covariates( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str | None = "short_id", + subject_col: str = "subject_id", + subject_type_col: str | None = "subject_type", + covariate_type_col: str | None = "covariate_type", + text_unit_ids_col: str | None = "text_unit_ids", + document_ids_col: str | None = "document_ids", + attributes_cols: list[str] | None = None, +) -> list[Covariate]: + """Read covariates from a dataframe.""" + covariates = [] + for idx, row in df.iterrows(): + cov = Covariate( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + subject_id=to_str(row, subject_col), + subject_type=( + to_str(row, subject_type_col) if subject_type_col else "entity" + ), + covariate_type=( + to_str(row, covariate_type_col) if covariate_type_col else "claim" + ), + text_unit_ids=to_optional_list(row, text_unit_ids_col, item_type=str), + document_ids=to_optional_list(row, document_ids_col, item_type=str), + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + covariates.append(cov) + return covariates + + +def read_communities( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str | None = "short_id", + title_col: str = "title", + level_col: str = "level", + entities_col: str | None = "entity_ids", + relationships_col: str | None = "relationship_ids", + covariates_col: str | None = "covariate_ids", + attributes_cols: list[str] | None = None, +) -> list[Community]: + """Read communities from a dataframe.""" + communities = [] + for idx, row in df.iterrows(): + comm = 
Community( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + title=to_str(row, title_col), + level=to_str(row, level_col), + entity_ids=to_optional_list(row, entities_col, item_type=str), + relationship_ids=to_optional_list(row, relationships_col, item_type=str), + covariate_ids=to_optional_dict( + row, covariates_col, key_type=str, value_type=str + ), + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + communities.append(comm) + return communities + + +def read_community_reports( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str | None = "short_id", + title_col: str = "title", + community_col: str = "community", + summary_col: str = "summary", + content_col: str = "full_content", + rank_col: str | None = "rank", + summary_embedding_col: str | None = "summary_embedding", + content_embedding_col: str | None = "full_content_embedding", + attributes_cols: list[str] | None = None, +) -> list[CommunityReport]: + """Read community reports from a dataframe.""" + reports = [] + for idx, row in df.iterrows(): + report = CommunityReport( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + title=to_str(row, title_col), + community_id=to_str(row, community_col), + summary=to_str(row, summary_col), + full_content=to_str(row, content_col), + rank=to_optional_float(row, rank_col), + summary_embedding=to_optional_list( + row, summary_embedding_col, item_type=float + ), + full_content_embedding=to_optional_list( + row, content_embedding_col, item_type=float + ), + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + reports.append(report) + return reports + + +def read_text_units( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str | None = "short_id", + text_col: str = "text", + entities_col: str | None = "entity_ids", + relationships_col: str | None = "relationship_ids", + covariates_col: str | None = "covariate_ids", + tokens_col: str | None = "n_tokens", + document_ids_col: str | None = "document_ids", + embedding_col: str | None = "text_embedding", + attributes_cols: list[str] | None = None, +) -> list[TextUnit]: + """Read text units from a dataframe.""" + text_units = [] + for idx, row in df.iterrows(): + chunk = TextUnit( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + text=to_str(row, text_col), + entity_ids=to_optional_list(row, entities_col, item_type=str), + relationship_ids=to_optional_list(row, relationships_col, item_type=str), + covariate_ids=to_optional_dict( + row, covariates_col, key_type=str, value_type=str + ), + text_embedding=to_optional_list(row, embedding_col, item_type=float), # type: ignore + n_tokens=to_optional_int(row, tokens_col), + document_ids=to_optional_list(row, document_ids_col, item_type=str), + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + text_units.append(chunk) + return text_units + + +def read_documents( + df: pd.DataFrame, + id_col: str = "id", + short_id_col: str = "short_id", + title_col: str = "title", + type_col: str = "type", + summary_col: str | None = "entities", + raw_content_col: str | None = "relationships", + summary_embedding_col: str | None = "summary_embedding", + content_embedding_col: str | None = "raw_content_embedding", + text_units_col: str | None = "text_units", + attributes_cols: list[str] | 
None = None, +) -> list[Document]: + """Read documents from a dataframe.""" + docs = [] + for idx, row in df.iterrows(): + doc = Document( + id=to_str(row, id_col), + short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx), + title=to_str(row, title_col), + type=to_str(row, type_col), + summary=to_optional_str(row, summary_col), + raw_content=to_str(row, raw_content_col), + summary_embedding=to_optional_list( + row, summary_embedding_col, item_type=float + ), + raw_content_embedding=to_optional_list( + row, content_embedding_col, item_type=float + ), + text_units=to_list(row, text_units_col, item_type=str), # type: ignore + attributes=( + {col: row.get(col) for col in attributes_cols} + if attributes_cols + else None + ), + ) + docs.append(doc) + return docs diff --git a/func-app/graphrag/query/input/loaders/utils.py b/func-app/graphrag/query/input/loaders/utils.py new file mode 100644 index 0000000000..e0fffd2467 --- /dev/null +++ b/func-app/graphrag/query/input/loaders/utils.py @@ -0,0 +1,245 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Data load utils.""" + +import numpy as np +import pandas as pd + + +def to_str(data: pd.Series, column_name: str | None) -> str: + """Convert and validate a value to a string.""" + if column_name is None: + msg = "Column name is None" + raise ValueError(msg) + + if column_name in data: + return str(data[column_name]) + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + +def to_optional_str(data: pd.Series, column_name: str | None) -> str | None: + """Convert and validate a value to an optional string.""" + if column_name is None: + msg = "Column name is None" + raise ValueError(msg) + + if column_name in data: + value = data[column_name] + if value is None: + return None + return str(data[column_name]) + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + +def to_list( + data: pd.Series, column_name: str | None, item_type: type | None = None +) -> list: + """Convert and validate a value to a list.""" + if column_name is None: + msg = "Column name is None" + raise ValueError(msg) + + if column_name in data: + value = data[column_name] + if isinstance(value, np.ndarray): + value = value.tolist() + + if not isinstance(value, list): + msg = f"value is not a list: {value} ({type(value)})" + raise ValueError(msg) + + if item_type is not None: + for v in value: + if not isinstance(v, item_type): + msg = f"list item has item that is not {item_type}: {v} ({type(v)})" + raise TypeError(msg) + return value + + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + +def to_optional_list( + data: pd.Series, column_name: str | None, item_type: type | None = None +) -> list | None: + """Convert and validate a value to an optional list.""" + if column_name is None: + return None + + if column_name in data: + value = data[column_name] # type: ignore + if value is None: + return None + + if isinstance(value, np.ndarray): + value = value.tolist() + + if not isinstance(value, list): + msg = f"value is not a list: {value} ({type(value)})" + raise ValueError(msg) + + if item_type is not None: + for v in value: + if not isinstance(v, item_type): + msg = f"list item has item that is not {item_type}: {v} ({type(v)})" + raise TypeError(msg) + return value + + return None + + +def to_int(data: pd.Series, column_name: str | None) -> int: + """Convert and validate a value to an int.""" + if column_name is None: + msg = "Column name is None" + raise 
ValueError(msg) + + if column_name in data: + value = data[column_name] + if isinstance(value, float): + value = int(value) + if not isinstance(value, int): + msg = f"value is not an int: {value} ({type(value)})" + raise ValueError(msg) + else: + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + return int(value) + + +def to_optional_int(data: pd.Series, column_name: str | None) -> int | None: + """Convert and validate a value to an optional int.""" + if column_name is None: + return None + + if column_name in data: + value = data[column_name] + + if value is None: + return None + + if isinstance(value, float): + value = int(value) + if not isinstance(value, int): + msg = f"value is not an int: {value} ({type(value)})" + raise ValueError(msg) + else: + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + return int(value) + + +def to_float(data: pd.Series, column_name: str | None) -> float: + """Convert and validate a value to a float.""" + if column_name is None: + msg = "Column name is None" + raise ValueError(msg) + + if column_name in data: + value = data[column_name] + if not isinstance(value, float): + msg = f"value is not a float: {value} ({type(value)})" + raise ValueError(msg) + else: + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + return float(value) + + +def to_optional_float(data: pd.Series, column_name: str | None) -> float | None: + """Convert and validate a value to an optional float.""" + if column_name is None: + return None + + if column_name in data: + value = data[column_name] + if value is None: + return None + if not isinstance(value, float): + msg = f"value is not a float: {value} ({type(value)})" + raise ValueError(msg) + else: + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + return float(value) + + +def to_dict( + data: pd.Series, + column_name: str | None, + key_type: type | None = None, + value_type: type | None = None, +) -> dict: + """Convert and validate a value to a dict.""" + if column_name is None: + msg = "Column name is None" + raise ValueError(msg) + + if column_name in data: + value = data[column_name] + if not isinstance(value, dict): + msg = f"value is not a dict: {value} ({type(value)})" + raise ValueError(msg) + + if key_type is not None: + for v in value: + if not isinstance(v, key_type): + msg = f"dict key has item that is not {key_type}: {v} ({type(v)})" + raise TypeError(msg) + + if value_type is not None: + for v in value.values(): + if not isinstance(v, value_type): + msg = ( + f"dict value has item that is not {value_type}: {v} ({type(v)})" + ) + raise TypeError(msg) + return value + + msg = f"Column {column_name} not found in data" + raise ValueError(msg) + + +def to_optional_dict( + data: pd.Series, + column_name: str | None, + key_type: type | None = None, + value_type: type | None = None, +) -> dict | None: + """Convert and validate a value to an optional dict.""" + if column_name is None: + return None + + if column_name in data: + value = data[column_name] + if value is None: + return None + if not isinstance(value, dict): + msg = f"value is not a dict: {value} ({type(value)})" + raise TypeError(msg) + + if key_type is not None: + for v in value: + if not isinstance(v, key_type): + msg = f"dict key has item that is not {key_type}: {v} ({type(v)})" + raise TypeError(msg) + + if value_type is not None: + for v in value.values(): + if not isinstance(v, value_type): + msg = ( + f"dict value has item that is not {value_type}: {v} 
({type(v)})" + ) + raise TypeError(msg) + + return value + + msg = f"Column {column_name} not found in data" + raise ValueError(msg) diff --git a/func-app/graphrag/query/input/retrieval/__init__.py b/func-app/graphrag/query/input/retrieval/__init__.py new file mode 100644 index 0000000000..75c2f9f095 --- /dev/null +++ b/func-app/graphrag/query/input/retrieval/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""GraphRAG Orchestration Input Retrieval.""" diff --git a/func-app/graphrag/query/input/retrieval/community_reports.py b/func-app/graphrag/query/input/retrieval/community_reports.py new file mode 100644 index 0000000000..bd4933f1f9 --- /dev/null +++ b/func-app/graphrag/query/input/retrieval/community_reports.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Util functions to retrieve community reports from a collection.""" + +from typing import Any, cast + +import pandas as pd + +from graphrag.model import CommunityReport, Entity + + +def get_candidate_communities( + selected_entities: list[Entity], + community_reports: list[CommunityReport], + include_community_rank: bool = False, + use_community_summary: bool = False, +) -> pd.DataFrame: + """Get all communities that are related to selected entities.""" + selected_community_ids = [ + entity.community_ids for entity in selected_entities if entity.community_ids + ] + selected_community_ids = [ + item for sublist in selected_community_ids for item in sublist + ] + selected_reports = [ + community + for community in community_reports + if community.id in selected_community_ids + ] + return to_community_report_dataframe( + reports=selected_reports, + include_community_rank=include_community_rank, + use_community_summary=use_community_summary, + ) + + +def to_community_report_dataframe( + reports: list[CommunityReport], + include_community_rank: bool = False, + use_community_summary: bool = False, +) -> pd.DataFrame: + """Convert a list of communities to a pandas dataframe.""" + if len(reports) == 0: + return pd.DataFrame() + + # add header + header = ["id", "title"] + attribute_cols = list(reports[0].attributes.keys()) if reports[0].attributes else [] + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + header.append("summary" if use_community_summary else "content") + if include_community_rank: + header.append("rank") + + records = [] + for report in reports: + new_record = [ + report.short_id if report.short_id else "", + report.title, + *[ + str(report.attributes.get(field, "")) + if report.attributes and report.attributes.get(field) + else "" + for field in attribute_cols + ], + ] + new_record.append( + report.summary if use_community_summary else report.full_content + ) + if include_community_rank: + new_record.append(str(report.rank)) + records.append(new_record) + return pd.DataFrame(records, columns=cast(Any, header)) diff --git a/func-app/graphrag/query/input/retrieval/covariates.py b/func-app/graphrag/query/input/retrieval/covariates.py new file mode 100644 index 0000000000..1c45203d01 --- /dev/null +++ b/func-app/graphrag/query/input/retrieval/covariates.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Util functions to retrieve covariates from a collection.""" + +from typing import Any, cast + +import pandas as pd + +from graphrag.model import Covariate, Entity + + +def get_candidate_covariates( + selected_entities: list[Entity], + covariates: list[Covariate], +) -> list[Covariate]: + """Get all covariates that are related to selected entities.""" + selected_entity_names = [entity.title for entity in selected_entities] + return [ + covariate + for covariate in covariates + if covariate.subject_id in selected_entity_names + ] + + +def to_covariate_dataframe(covariates: list[Covariate]) -> pd.DataFrame: + """Convert a list of covariates to a pandas dataframe.""" + if len(covariates) == 0: + return pd.DataFrame() + + # add header + header = ["id", "entity"] + attributes = covariates[0].attributes or {} if len(covariates) > 0 else {} + attribute_cols = list(attributes.keys()) if len(covariates) > 0 else [] + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + + records = [] + for covariate in covariates: + new_record = [ + covariate.short_id if covariate.short_id else "", + covariate.subject_id, + ] + for field in attribute_cols: + field_value = ( + str(covariate.attributes.get(field)) + if covariate.attributes and covariate.attributes.get(field) + else "" + ) + new_record.append(field_value) + records.append(new_record) + return pd.DataFrame(records, columns=cast(Any, header)) diff --git a/func-app/graphrag/query/input/retrieval/entities.py b/func-app/graphrag/query/input/retrieval/entities.py new file mode 100644 index 0000000000..5465f9f59e --- /dev/null +++ b/func-app/graphrag/query/input/retrieval/entities.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Util functions to get entities from a collection.""" + +import uuid +from collections.abc import Iterable +from typing import Any, cast + +import pandas as pd + +from graphrag.model import Entity + + +def get_entity_by_key( + entities: Iterable[Entity], key: str, value: str | int +) -> Entity | None: + """Get entity by key.""" + for entity in entities: + if isinstance(value, str) and is_valid_uuid(value): + if getattr(entity, key) == value or getattr(entity, key) == value.replace( + "-", "" + ): + return entity + else: + if getattr(entity, key) == value: + return entity + return None + + +def get_entity_by_name(entities: Iterable[Entity], entity_name: str) -> list[Entity]: + """Get entities by name.""" + return [entity for entity in entities if entity.title == entity_name] + + +def get_entity_by_attribute( + entities: Iterable[Entity], attribute_name: str, attribute_value: Any +) -> list[Entity]: + """Get entities by attribute.""" + return [ + entity + for entity in entities + if entity.attributes + and entity.attributes.get(attribute_name) == attribute_value + ] + + +def to_entity_dataframe( + entities: list[Entity], + include_entity_rank: bool = True, + rank_description: str = "number of relationships", +) -> pd.DataFrame: + """Convert a list of entities to a pandas dataframe.""" + if len(entities) == 0: + return pd.DataFrame() + header = ["id", "entity", "description"] + if include_entity_rank: + header.append(rank_description) + attribute_cols = ( + list(entities[0].attributes.keys()) if entities[0].attributes else [] + ) + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + + records = [] + for entity in entities: + new_record = [ + entity.short_id if entity.short_id else "", + entity.title, + entity.description if entity.description else "", + ] + if include_entity_rank: + new_record.append(str(entity.rank)) + + for field in attribute_cols: + field_value = ( + str(entity.attributes.get(field)) + if entity.attributes and entity.attributes.get(field) + else "" + ) + new_record.append(field_value) + records.append(new_record) + return pd.DataFrame(records, columns=cast(Any, header)) + + +def is_valid_uuid(value: str) -> bool: + """Determine if a string is a valid UUID.""" + try: + uuid.UUID(str(value)) + except ValueError: + return False + else: + return True diff --git a/func-app/graphrag/query/input/retrieval/relationships.py b/func-app/graphrag/query/input/retrieval/relationships.py new file mode 100644 index 0000000000..7be258d23c --- /dev/null +++ b/func-app/graphrag/query/input/retrieval/relationships.py @@ -0,0 +1,217 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Util functions to retrieve relationships from a collection."""
+
+import time
+from typing import Any, cast
+
+import pandas as pd
+
+from common.graph_db_client import GraphDBClient
+from graphrag.model import Entity, Relationship
+
+from graphrag.query.input.loaders.dfs import read_relationships
+
+def get_relationships_from_graphdb(query: str, selected_entity_names: list[str], graphdb_client: GraphDBClient):
+    """Run a Gremlin traversal and adapt the results into Relationship objects."""
+    relationships_result = graphdb_client._client.submit(
+        message=query,
+        bindings={
+            "prop_selected_entity_names": selected_entity_names,
+        }
+    )
+    time.sleep(5)  # crude wait for the async Gremlin submission to complete
+    return read_relationships(
+        graphdb_client.result_to_df(relationships_result),
+        short_id_col="human_readable_id"
+    )
+
+def get_in_network_relationships(
+    selected_entities: list[Entity],
+    relationships: list[Relationship],
+    ranking_attribute: str = "rank",
+    graphdb_client: GraphDBClient|None=None,
+) -> list[Relationship]:
+    """Get all directed relationships between selected entities, sorted by ranking_attribute."""
+    selected_entity_names = [entity.title for entity in selected_entities]
+    if not graphdb_client:
+        selected_relationships = [
+            relationship
+            for relationship in relationships
+            if relationship.source in selected_entity_names
+            and relationship.target in selected_entity_names
+        ]
+    else:
+        selected_relationships = get_relationships_from_graphdb(
+            query=(
+                "g.E()"
+                ".where(inV().has('name',within(prop_selected_entity_names)))"
+                ".where(outV().has('name',within(prop_selected_entity_names)))"
+            ),
+            selected_entity_names=selected_entity_names,
+            graphdb_client=graphdb_client
+        )
+    if len(selected_relationships) <= 1:
+        return selected_relationships
+
+    # sort by ranking attribute
+    return sort_relationships_by_ranking_attribute(
+        selected_relationships, selected_entities, ranking_attribute
+    )
+
+
+def get_out_network_relationships(
+    selected_entities: list[Entity],
+    relationships: list[Relationship],
+    ranking_attribute: str = "rank",
+    graphdb_client: GraphDBClient|None=None,
+) -> list[Relationship]:
+    """Get relationships from selected entities to other entities that are not within the selected entities, sorted by ranking_attribute."""
+    selected_entity_names = [entity.title for entity in selected_entities]
+    if not graphdb_client:
+        source_relationships = [
+            relationship
+            for relationship in relationships
+            if relationship.source in selected_entity_names
+            and relationship.target not in selected_entity_names
+        ]
+        target_relationships = [
+            relationship
+            for relationship in relationships
+            if relationship.target in selected_entity_names
+            and relationship.source not in selected_entity_names
+        ]
+        selected_relationships = source_relationships + target_relationships
+    else:
+        selected_relationships = get_relationships_from_graphdb(
+            query=(
+                "g.E().union("
+                "__.where(outV().has('name',without(prop_selected_entity_names)))"
+                ".where(inV().has('name',within(prop_selected_entity_names))),"
+                "__.where(inV().has('name',without(prop_selected_entity_names)))"
+                ".where(outV().has('name',within(prop_selected_entity_names)))"
+                ")"
+            ),
+            selected_entity_names=selected_entity_names,
+            graphdb_client=graphdb_client
+        )
+    return sort_relationships_by_ranking_attribute(
+        selected_relationships, selected_entities, ranking_attribute
+    )
+
+
+def get_candidate_relationships(
+    selected_entities: list[Entity],
+    relationships: list[Relationship],
+) -> list[Relationship]:
+    """Get all relationships that
are associated with the selected entities.""" + selected_entity_names = [entity.title for entity in selected_entities] + return [ + relationship + for relationship in relationships + if relationship.source in selected_entity_names + or relationship.target in selected_entity_names + ] + + +def get_entities_from_relationships( + relationships: list[Relationship], entities: list[Entity] +) -> list[Entity]: + """Get all entities that are associated with the selected relationships.""" + selected_entity_names = [relationship.source for relationship in relationships] + [ + relationship.target for relationship in relationships + ] + return [entity for entity in entities if entity.title in selected_entity_names] + + +def calculate_relationship_combined_rank( + relationships: list[Relationship], + entities: list[Entity], + ranking_attribute: str = "rank", +) -> list[Relationship]: + """Calculate default rank for a relationship based on the combined rank of source and target entities.""" + entity_mappings = {entity.title: entity for entity in entities} + + for relationship in relationships: + if relationship.attributes is None: + relationship.attributes = {} + source = entity_mappings.get(relationship.source) + target = entity_mappings.get(relationship.target) + source_rank = source.rank if source and source.rank else 0 + target_rank = target.rank if target and target.rank else 0 + relationship.attributes[ranking_attribute] = source_rank + target_rank # type: ignore + return relationships + + +def sort_relationships_by_ranking_attribute( + relationships: list[Relationship], + entities: list[Entity], + ranking_attribute: str = "rank", +) -> list[Relationship]: + """ + Sort relationships by a ranking_attribute. + + If no ranking attribute exists, sort by combined rank of source and target entities. 
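+    Fallback order: the named attribute when present on the relationships, then edge weight when ranking_attribute is "weight", otherwise a combined source/target entity rank is computed.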
+ """ + if len(relationships) == 0: + return relationships + + # sort by ranking attribute + attribute_names = ( + list(relationships[0].attributes.keys()) if relationships[0].attributes else [] + ) + if ranking_attribute in attribute_names: + relationships.sort( + key=lambda x: int(x.attributes[ranking_attribute]) if x.attributes else 0, + reverse=True, + ) + elif ranking_attribute == "weight": + relationships.sort(key=lambda x: x.weight if x.weight else 0.0, reverse=True) + else: + # ranking attribute do not exist, calculate rank = combined ranks of source and target + relationships = calculate_relationship_combined_rank( + relationships, entities, ranking_attribute + ) + relationships.sort( + key=lambda x: int(x.attributes[ranking_attribute]) if x.attributes else 0, + reverse=True, + ) + return relationships + + +def to_relationship_dataframe( + relationships: list[Relationship], include_relationship_weight: bool = True +) -> pd.DataFrame: + """Convert a list of relationships to a pandas dataframe.""" + if len(relationships) == 0: + return pd.DataFrame() + + header = ["id", "source", "target", "description"] + if include_relationship_weight: + header.append("weight") + attribute_cols = ( + list(relationships[0].attributes.keys()) if relationships[0].attributes else [] + ) + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + + records = [] + for rel in relationships: + new_record = [ + rel.short_id if rel.short_id else "", + rel.source, + rel.target, + rel.description if rel.description else "", + ] + if include_relationship_weight: + new_record.append(str(rel.weight if rel.weight else "")) + for field in attribute_cols: + field_value = ( + str(rel.attributes.get(field)) + if rel.attributes and rel.attributes.get(field) + else "" + ) + new_record.append(field_value) + records.append(new_record) + return pd.DataFrame(records, columns=cast(Any, header)) diff --git a/func-app/graphrag/query/input/retrieval/text_units.py b/func-app/graphrag/query/input/retrieval/text_units.py new file mode 100644 index 0000000000..a00dc20a0a --- /dev/null +++ b/func-app/graphrag/query/input/retrieval/text_units.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Util functions to retrieve text units from a collection.""" + +from typing import Any, cast + +import pandas as pd + +from graphrag.model import Entity, TextUnit + + +def get_candidate_text_units( + selected_entities: list[Entity], + text_units: list[TextUnit], +) -> pd.DataFrame: + """Get all text units that are associated to selected entities.""" + selected_text_ids = [ + entity.text_unit_ids for entity in selected_entities if entity.text_unit_ids + ] + selected_text_ids = [item for sublist in selected_text_ids for item in sublist] + selected_text_units = [unit for unit in text_units if unit.id in selected_text_ids] + return to_text_unit_dataframe(selected_text_units) + + +def to_text_unit_dataframe(text_units: list[TextUnit]) -> pd.DataFrame: + """Convert a list of text units to a pandas dataframe.""" + if len(text_units) == 0: + return pd.DataFrame() + + # add header + header = ["id", "text"] + attribute_cols = ( + list(text_units[0].attributes.keys()) if text_units[0].attributes else [] + ) + attribute_cols = [col for col in attribute_cols if col not in header] + header.extend(attribute_cols) + + records = [] + for unit in text_units: + new_record = [ + unit.short_id, + unit.text, + *[ + str(unit.attributes.get(field, "")) + if unit.attributes and unit.attributes.get(field) + else "" + for field in attribute_cols + ], + ] + records.append(new_record) + return pd.DataFrame(records, columns=cast(Any, header)) diff --git a/func-app/graphrag/query/llm/__init__.py b/func-app/graphrag/query/llm/__init__.py new file mode 100644 index 0000000000..b8f507b138 --- /dev/null +++ b/func-app/graphrag/query/llm/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Orchestration LLM utilities.""" diff --git a/func-app/graphrag/query/llm/base.py b/func-app/graphrag/query/llm/base.py new file mode 100644 index 0000000000..228150af50 --- /dev/null +++ b/func-app/graphrag/query/llm/base.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Base classes for LLM and Embedding models.""" + +from abc import ABC, abstractmethod +from typing import Any + + +class BaseLLMCallback: + """Base class for LLM callbacks.""" + + def __init__(self): + self.response = [] + + def on_llm_new_token(self, token: str): + """Handle when a new token is generated.""" + self.response.append(token) + + +class BaseLLM(ABC): + """The Base LLM implementation.""" + + @abstractmethod + def generate( + self, + messages: str | list[Any], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + """Generate a response.""" + + @abstractmethod + async def agenerate( + self, + messages: str | list[Any], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + """Generate a response asynchronously.""" + + +class BaseTextEmbedding(ABC): + """The text embedding interface.""" + + @abstractmethod + def embed(self, text: str, **kwargs: Any) -> list[float]: + """Embed a text string.""" + + @abstractmethod + async def aembed(self, text: str, **kwargs: Any) -> list[float]: + """Embed a text string asynchronously.""" diff --git a/func-app/graphrag/query/llm/oai/__init__.py b/func-app/graphrag/query/llm/oai/__init__.py new file mode 100644 index 0000000000..cbb257905e --- /dev/null +++ b/func-app/graphrag/query/llm/oai/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""GraphRAG Orchestration OpenAI Wrappers."""
+
+from .base import BaseOpenAILLM, OpenAILLMImpl, OpenAITextEmbeddingImpl
+from .chat_openai import ChatOpenAI
+from .embedding import OpenAIEmbedding
+from .openai import OpenAI
+from .typing import OPENAI_RETRY_ERROR_TYPES, OpenaiApiType
+
+__all__ = [
+    "OPENAI_RETRY_ERROR_TYPES",
+    "BaseOpenAILLM",
+    "ChatOpenAI",
+    "OpenAI",
+    "OpenAIEmbedding",
+    "OpenAILLMImpl",
+    "OpenAITextEmbeddingImpl",
+    "OpenaiApiType",
+]
diff --git a/func-app/graphrag/query/llm/oai/base.py b/func-app/graphrag/query/llm/oai/base.py
new file mode 100644
index 0000000000..6181c0b2a5
--- /dev/null
+++ b/func-app/graphrag/query/llm/oai/base.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Base classes for LLM and Embedding models."""
+
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
+
+from graphrag.query.llm.base import BaseTextEmbedding
+from graphrag.query.llm.oai.typing import OpenaiApiType
+from graphrag.query.progress import ConsoleStatusReporter, StatusReporter
+
+
+class BaseOpenAILLM(ABC):
+    """The Base OpenAI LLM implementation."""
+
+    _async_client: AsyncOpenAI | AsyncAzureOpenAI
+    _sync_client: OpenAI | AzureOpenAI
+
+    def __init__(self):
+        self._create_openai_client()
+
+    @abstractmethod
+    def _create_openai_client(self):
+        """Create a new synchronous and asynchronous OpenAI client instance."""
+
+    def set_clients(
+        self,
+        sync_client: OpenAI | AzureOpenAI,
+        async_client: AsyncOpenAI | AsyncAzureOpenAI,
+    ):
+        """
+        Set the synchronous and asynchronous clients used for making API requests.
+
+        Args:
+            sync_client (OpenAI | AzureOpenAI): The sync client object.
+            async_client (AsyncOpenAI | AsyncAzureOpenAI): The async client object.
+        """
+        self._sync_client = sync_client
+        self._async_client = async_client
+
+    @property
+    def async_client(self) -> AsyncOpenAI | AsyncAzureOpenAI | None:
+        """
+        Get the asynchronous client used for making API requests.
+
+        Returns
+        -------
+        AsyncOpenAI | AsyncAzureOpenAI: The async client object.
+        """
+        return self._async_client
+
+    @property
+    def sync_client(self) -> OpenAI | AzureOpenAI | None:
+        """
+        Get the synchronous client used for making API requests.
+
+        Returns
+        -------
+        OpenAI | AzureOpenAI: The sync client object.
+        """
+        return self._sync_client
+
+    @async_client.setter
+    def async_client(self, client: AsyncOpenAI | AsyncAzureOpenAI):
+        """
+        Set the asynchronous client used for making API requests.
+
+        Args:
+            client (AsyncOpenAI | AsyncAzureOpenAI): The async client object.
+        """
+        self._async_client = client
+
+    @sync_client.setter
+    def sync_client(self, client: OpenAI | AzureOpenAI):
+        """
+        Set the synchronous client used for making API requests.
+
+        Args:
+            client (OpenAI | AzureOpenAI): The sync client object.
+ """ + self._sync_client = client + + +class OpenAILLMImpl(BaseOpenAILLM): + """Orchestration OpenAI LLM Implementation.""" + + _reporter: StatusReporter = ConsoleStatusReporter() + + def __init__( + self, + api_key: str | None = None, + azure_ad_token_provider: Callable | None = None, + deployment_name: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + api_type: OpenaiApiType = OpenaiApiType.OpenAI, + organization: str | None = None, + max_retries: int = 10, + request_timeout: float = 180.0, + reporter: StatusReporter | None = None, + ): + self.api_key = api_key + self.azure_ad_token_provider = azure_ad_token_provider + self.deployment_name = deployment_name + self.api_base = api_base + self.api_version = api_version + self.api_type = api_type + self.organization = organization + self.max_retries = max_retries + self.request_timeout = request_timeout + self.reporter = reporter or ConsoleStatusReporter() + + try: + # Create OpenAI sync and async clients + super().__init__() + except Exception as e: + self._reporter.error( + message="Failed to create OpenAI client", + details={self.__class__.__name__: str(e)}, + ) + raise + + def _create_openai_client(self): + """Create a new OpenAI client instance.""" + if self.api_type == OpenaiApiType.AzureOpenAI: + if self.api_base is None: + msg = "api_base is required for Azure OpenAI" + raise ValueError(msg) + + sync_client = AzureOpenAI( + api_key=self.api_key, + azure_ad_token_provider=self.azure_ad_token_provider, + organization=self.organization, + # Azure-Specifics + api_version=self.api_version, + azure_endpoint=self.api_base, + azure_deployment=self.deployment_name, + # Retry Configuration + timeout=self.request_timeout, + max_retries=self.max_retries, + ) + + async_client = AsyncAzureOpenAI( + api_key=self.api_key, + azure_ad_token_provider=self.azure_ad_token_provider, + organization=self.organization, + # Azure-Specifics + api_version=self.api_version, + azure_endpoint=self.api_base, + azure_deployment=self.deployment_name, + # Retry Configuration + timeout=self.request_timeout, + max_retries=self.max_retries, + ) + self.set_clients(sync_client=sync_client, async_client=async_client) + + else: + sync_client = OpenAI( + api_key=self.api_key, + base_url=self.api_base, + organization=self.organization, + # Retry Configuration + timeout=self.request_timeout, + max_retries=self.max_retries, + ) + + async_client = AsyncOpenAI( + api_key=self.api_key, + base_url=self.api_base, + organization=self.organization, + # Retry Configuration + timeout=self.request_timeout, + max_retries=self.max_retries, + ) + self.set_clients(sync_client=sync_client, async_client=async_client) + + +class OpenAITextEmbeddingImpl(BaseTextEmbedding): + """Orchestration OpenAI Text Embedding Implementation.""" + + _reporter: StatusReporter | None = None + + def _create_openai_client(self, api_type: OpenaiApiType): + """Create a new synchronous and asynchronous OpenAI client instance.""" diff --git a/func-app/graphrag/query/llm/oai/chat_openai.py b/func-app/graphrag/query/llm/oai/chat_openai.py new file mode 100644 index 0000000000..92a9755b10 --- /dev/null +++ b/func-app/graphrag/query/llm/oai/chat_openai.py @@ -0,0 +1,206 @@ +# Copyright (c) 2024 Microsoft Corporation. 
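For the Azure branch of _create_openai_client above, a hedged wiring sketch: azure-identity's DefaultAzureCredential plus get_bearer_token_provider supply the azure_ad_token_provider callable; the endpoint, API version, and deployment values are placeholders, not values from this patch.

from azure.identity import DefaultAzureCredential, get_bearer_token_provider

token_provider = get_bearer_token_provider(
    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
llm_impl = OpenAILLMImpl(
    azure_ad_token_provider=token_provider,
    api_base="https://<your-resource>.openai.azure.com",  # placeholder endpoint
    api_version="2024-02-01",  # placeholder API version
    deployment_name="<your-deployment>",  # placeholder deployment name
    api_type=OpenaiApiType.AzureOpenAI,
)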
+# Licensed under the MIT License + +"""Chat-based OpenAI LLM implementation.""" + +from collections.abc import Callable +from typing import Any + +from tenacity import ( + AsyncRetrying, + RetryError, + Retrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) + +from graphrag.query.llm.base import BaseLLM, BaseLLMCallback +from graphrag.query.llm.oai.base import OpenAILLMImpl +from graphrag.query.llm.oai.typing import ( + OPENAI_RETRY_ERROR_TYPES, + OpenaiApiType, +) +from graphrag.query.progress import StatusReporter + +_MODEL_REQUIRED_MSG = "model is required" + + +class ChatOpenAI(BaseLLM, OpenAILLMImpl): + """Wrapper for OpenAI ChatCompletion models.""" + + def __init__( + self, + api_key: str | None = None, + model: str | None = None, + azure_ad_token_provider: Callable | None = None, + deployment_name: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + api_type: OpenaiApiType = OpenaiApiType.OpenAI, + organization: str | None = None, + max_retries: int = 10, + request_timeout: float = 180.0, + retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore + reporter: StatusReporter | None = None, + ): + OpenAILLMImpl.__init__( + self=self, + api_key=api_key, + azure_ad_token_provider=azure_ad_token_provider, + deployment_name=deployment_name, + api_base=api_base, + api_version=api_version, + api_type=api_type, # type: ignore + organization=organization, + max_retries=max_retries, + request_timeout=request_timeout, + reporter=reporter, + ) + self.model = model + self.retry_error_types = retry_error_types + + def generate( + self, + messages: str | list[Any], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + """Generate text.""" + try: + retryer = Retrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), + ) + for attempt in retryer: + with attempt: + return self._generate( + messages=messages, + streaming=streaming, + callbacks=callbacks, + **kwargs, + ) + except RetryError as e: + self._reporter.error( + message="Error at generate()", details={self.__class__.__name__: str(e)} + ) + return "" + else: + # TODO: why not just throw in this case? + return "" + + async def agenerate( + self, + messages: str | list[Any], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + """Generate text asynchronously.""" + try: + retryer = AsyncRetrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), # type: ignore + ) + async for attempt in retryer: + with attempt: + return await self._agenerate( + messages=messages, + streaming=streaming, + callbacks=callbacks, + **kwargs, + ) + except RetryError as e: + self._reporter.error(f"Error at agenerate(): {e}") + return "" + else: + # TODO: why not just throw in this case? 
+            return ""
+
+    def _generate(
+        self,
+        messages: str | list[Any],
+        streaming: bool = True,
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> str:
+        model = self.model
+        if not model:
+            raise ValueError(_MODEL_REQUIRED_MSG)
+        response = self.sync_client.chat.completions.create(  # type: ignore
+            model=model,
+            messages=messages,  # type: ignore
+            stream=streaming,
+            **kwargs,
+        )  # type: ignore
+        if streaming:
+            full_response = ""
+            while True:
+                try:
+                    chunk = response.__next__()  # type: ignore
+                    if not chunk or not chunk.choices:
+                        continue
+
+                    delta = (
+                        chunk.choices[0].delta.content
+                        if chunk.choices[0].delta and chunk.choices[0].delta.content
+                        else ""
+                    )  # type: ignore
+
+                    full_response += delta
+                    if callbacks:
+                        for callback in callbacks:
+                            callback.on_llm_new_token(delta)
+                    if chunk.choices[0].finish_reason == "stop":  # type: ignore
+                        break
+                except StopIteration:
+                    break
+            return full_response
+        return response.choices[0].message.content or ""  # type: ignore
+
+    async def _agenerate(
+        self,
+        messages: str | list[Any],
+        streaming: bool = True,
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> str:
+        model = self.model
+        if not model:
+            raise ValueError(_MODEL_REQUIRED_MSG)
+        response = await self.async_client.chat.completions.create(  # type: ignore
+            model=model,
+            messages=messages,  # type: ignore
+            stream=streaming,
+            **kwargs,
+        )
+        if streaming:
+            full_response = ""
+            while True:
+                try:
+                    chunk = await response.__anext__()  # type: ignore
+                    if not chunk or not chunk.choices:
+                        continue
+
+                    delta = (
+                        chunk.choices[0].delta.content
+                        if chunk.choices[0].delta and chunk.choices[0].delta.content
+                        else ""
+                    )  # type: ignore
+
+                    full_response += delta
+                    if callbacks:
+                        for callback in callbacks:
+                            callback.on_llm_new_token(delta)
+                    if chunk.choices[0].finish_reason == "stop":  # type: ignore
+                        break
+                except StopAsyncIteration:
+                    break
+            return full_response
+
+        return response.choices[0].message.content or ""  # type: ignore
diff --git a/func-app/graphrag/query/llm/oai/embedding.py b/func-app/graphrag/query/llm/oai/embedding.py
new file mode 100644
index 0000000000..f40372dbce
--- /dev/null
+++ b/func-app/graphrag/query/llm/oai/embedding.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2024 Microsoft Corporation.
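A usage sketch for ChatOpenAI's streaming path above, assuming a valid key and model (both placeholders here): the callback sees each delta as it arrives, and generate() still returns the assembled string.

class PrintingCallback(BaseLLMCallback):
    """Print tokens as they stream; the inherited .response list keeps them."""

    def on_llm_new_token(self, token: str):
        super().on_llm_new_token(token)
        print(token, end="", flush=True)


chat = ChatOpenAI(api_key="<OPENAI_API_KEY>", model="gpt-4-turbo")  # placeholders
answer = chat.generate(
    messages=[{"role": "user", "content": "Summarize the key entities."}],
    streaming=True,
    callbacks=[PrintingCallback()],
)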
+# Licensed under the MIT License + +"""OpenAI Embedding model implementation.""" + +import asyncio +from collections.abc import Callable +from typing import Any + +import numpy as np +import tiktoken +from tenacity import ( + AsyncRetrying, + RetryError, + Retrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) + +from graphrag.query.llm.base import BaseTextEmbedding +from graphrag.query.llm.oai.base import OpenAILLMImpl +from graphrag.query.llm.oai.typing import ( + OPENAI_RETRY_ERROR_TYPES, + OpenaiApiType, +) +from graphrag.query.llm.text_utils import chunk_text +from graphrag.query.progress import StatusReporter + + +class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl): + """Wrapper for OpenAI Embedding models.""" + + def __init__( + self, + api_key: str | None = None, + azure_ad_token_provider: Callable | None = None, + model: str = "text-embedding-3-small", + deployment_name: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + api_type: OpenaiApiType = OpenaiApiType.OpenAI, + organization: str | None = None, + encoding_name: str = "cl100k_base", + max_tokens: int = 8191, + max_retries: int = 10, + request_timeout: float = 180.0, + retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore + reporter: StatusReporter | None = None, + ): + OpenAILLMImpl.__init__( + self=self, + api_key=api_key, + azure_ad_token_provider=azure_ad_token_provider, + deployment_name=deployment_name, + api_base=api_base, + api_version=api_version, + api_type=api_type, # type: ignore + organization=organization, + max_retries=max_retries, + request_timeout=request_timeout, + reporter=reporter, + ) + + self.model = model + self.encoding_name = encoding_name + self.max_tokens = max_tokens + self.token_encoder = tiktoken.get_encoding(self.encoding_name) + self.retry_error_types = retry_error_types + + def embed(self, text: str, **kwargs: Any) -> list[float]: + """ + Embed text using OpenAI Embedding's sync function. + + For text longer than max_tokens, chunk texts into max_tokens, embed each chunk, then combine using weighted average. + Please refer to: https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb + """ + token_chunks = chunk_text( + text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens + ) + chunk_embeddings = [] + chunk_lens = [] + for chunk in token_chunks: + try: + embedding, chunk_len = self._embed_with_retry(chunk, **kwargs) + chunk_embeddings.append(embedding) + chunk_lens.append(chunk_len) + # TODO: catch a more specific exception + except Exception as e: # noqa BLE001 + self._reporter.error( + message="Error embedding chunk", + details={self.__class__.__name__: str(e)}, + ) + + continue + chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens) + chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings) + return chunk_embeddings.tolist() + + async def aembed(self, text: str, **kwargs: Any) -> list[float]: + """ + Embed text using OpenAI Embedding's async function. + + For text longer than max_tokens, chunk texts into max_tokens, embed each chunk, then combine using weighted average. 
+ """ + token_chunks = chunk_text( + text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens + ) + chunk_embeddings = [] + chunk_lens = [] + embedding_results = await asyncio.gather(*[ + self._aembed_with_retry(chunk, **kwargs) for chunk in token_chunks + ]) + embedding_results = [result for result in embedding_results if result[0]] + chunk_embeddings = [result[0] for result in embedding_results] + chunk_lens = [result[1] for result in embedding_results] + chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens) # type: ignore + chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings) + return chunk_embeddings.tolist() + + def _embed_with_retry( + self, text: str | tuple, **kwargs: Any + ) -> tuple[list[float], int]: + try: + retryer = Retrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), + ) + for attempt in retryer: + with attempt: + embedding = ( + self.sync_client.embeddings.create( # type: ignore + input=text, + model=self.model, + **kwargs, # type: ignore + ) + .data[0] + .embedding + or [] + ) + return (embedding, len(text)) + except RetryError as e: + self._reporter.error( + message="Error at embed_with_retry()", + details={self.__class__.__name__: str(e)}, + ) + return ([], 0) + else: + # TODO: why not just throw in this case? + return ([], 0) + + async def _aembed_with_retry( + self, text: str | tuple, **kwargs: Any + ) -> tuple[list[float], int]: + try: + retryer = AsyncRetrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), + ) + async for attempt in retryer: + with attempt: + embedding = ( + await self.async_client.embeddings.create( # type: ignore + input=text, + model=self.model, + **kwargs, # type: ignore + ) + ).data[0].embedding or [] + return (embedding, len(text)) + except RetryError as e: + self._reporter.error( + message="Error at embed_with_retry()", + details={self.__class__.__name__: str(e)}, + ) + return ([], 0) + else: + # TODO: why not just throw in this case? + return ([], 0) diff --git a/func-app/graphrag/query/llm/oai/openai.py b/func-app/graphrag/query/llm/oai/openai.py new file mode 100644 index 0000000000..76bb5fe52c --- /dev/null +++ b/func-app/graphrag/query/llm/oai/openai.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024 Microsoft Corporation. 
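The combine step used by embed()/aembed() above is a length-weighted average followed by L2 normalization; this standalone numpy sketch reproduces it with made-up two-dimensional vectors to show how longer chunks dominate the result.

import numpy as np

chunk_embeddings = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
chunk_lens = [3, 1]  # token counts act as weights
combined = np.average(chunk_embeddings, axis=0, weights=chunk_lens)  # [0.75, 0.25]
combined = combined / np.linalg.norm(combined)
print(combined.tolist())  # ~[0.9487, 0.3162]; the 3-token chunk dominates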
+# Licensed under the MIT License + +"""OpenAI Wrappers for Orchestration.""" + +import logging +from typing import Any + +from tenacity import ( + AsyncRetrying, + RetryError, + Retrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) + +from graphrag.query.llm.base import BaseLLMCallback +from graphrag.query.llm.oai.base import OpenAILLMImpl +from graphrag.query.llm.oai.typing import ( + OPENAI_RETRY_ERROR_TYPES, + OpenaiApiType, +) + +log = logging.getLogger(__name__) + + +class OpenAI(OpenAILLMImpl): + """Wrapper for OpenAI Completion models.""" + + def __init__( + self, + api_key: str, + model: str, + deployment_name: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + api_type: OpenaiApiType = OpenaiApiType.OpenAI, + organization: str | None = None, + max_retries: int = 10, + retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore + ): + self.api_key = api_key + self.model = model + self.deployment_name = deployment_name + self.api_base = api_base + self.api_version = api_version + self.api_type = api_type + self.organization = organization + self.max_retries = max_retries + self.retry_error_types = retry_error_types + + def generate( + self, + messages: str | list[str], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + """Generate text.""" + try: + retryer = Retrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), + ) + for attempt in retryer: + with attempt: + return self._generate( + messages=messages, + streaming=streaming, + callbacks=callbacks, + **kwargs, + ) + except RetryError: + log.exception("RetryError at generate(): %s") + return "" + else: + # TODO: why not just throw in this case? + return "" + + async def agenerate( + self, + messages: str | list[str], + streaming: bool = True, + callbacks: list[BaseLLMCallback] | None = None, + **kwargs: Any, + ) -> str: + """Generate Text Asynchronously.""" + try: + retryer = AsyncRetrying( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential_jitter(max=10), + reraise=True, + retry=retry_if_exception_type(self.retry_error_types), + ) + async for attempt in retryer: + with attempt: + return await self._agenerate( + messages=messages, + streaming=streaming, + callbacks=callbacks, + **kwargs, + ) + except RetryError: + log.exception("Error at agenerate()") + return "" + else: + # TODO: why not just throw in this case? 
+            return ""
+
+    def _generate(
+        self,
+        messages: str | list[str],
+        streaming: bool = True,
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> str:
+        response = self.sync_client.chat.completions.create(  # type: ignore
+            model=self.model,
+            messages=messages,  # type: ignore
+            stream=streaming,
+            **kwargs,
+        )  # type: ignore
+        if streaming:
+            full_response = ""
+            while True:
+                try:
+                    chunk = response.__next__()  # type: ignore
+                    if not chunk or not chunk.choices:
+                        continue
+
+                    delta = (
+                        chunk.choices[0].delta.content
+                        if chunk.choices[0].delta and chunk.choices[0].delta.content
+                        else ""
+                    )  # type: ignore
+
+                    full_response += delta
+                    if callbacks:
+                        for callback in callbacks:
+                            callback.on_llm_new_token(delta)
+                    if chunk.choices[0].finish_reason == "stop":  # type: ignore
+                        break
+                except StopIteration:
+                    break
+            return full_response
+        return response.choices[0].message.content or ""  # type: ignore
+
+    async def _agenerate(
+        self,
+        messages: str | list[str],
+        streaming: bool = True,
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> str:
+        response = await self.async_client.chat.completions.create(  # type: ignore
+            model=self.model,
+            messages=messages,  # type: ignore
+            stream=streaming,
+            **kwargs,
+        )
+        if streaming:
+            full_response = ""
+            while True:
+                try:
+                    chunk = await response.__anext__()  # type: ignore
+                    if not chunk or not chunk.choices:
+                        continue
+
+                    delta = (
+                        chunk.choices[0].delta.content
+                        if chunk.choices[0].delta and chunk.choices[0].delta.content
+                        else ""
+                    )  # type: ignore
+
+                    full_response += delta
+                    if callbacks:
+                        for callback in callbacks:
+                            callback.on_llm_new_token(delta)
+                    if chunk.choices[0].finish_reason == "stop":  # type: ignore
+                        break
+                except StopAsyncIteration:
+                    break
+            return full_response
+        return response.choices[0].message.content or ""  # type: ignore
diff --git a/func-app/graphrag/query/llm/oai/typing.py b/func-app/graphrag/query/llm/oai/typing.py
new file mode 100644
index 0000000000..399a82f699
--- /dev/null
+++ b/func-app/graphrag/query/llm/oai/typing.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""OpenAI wrapper options."""
+
+from enum import Enum
+from typing import Any, cast
+
+import openai
+
+OPENAI_RETRY_ERROR_TYPES = (
+    # TODO: update these when we update to OpenAI 1+ library
+    cast(Any, openai).RateLimitError,
+    cast(Any, openai).APIConnectionError,
+    # TODO: replace with comparable OpenAI 1+ error
+)
+
+
+class OpenaiApiType(str, Enum):
+    """The OpenAI Flavor."""
+
+    OpenAI = "openai"
+    AzureOpenAI = "azure"
diff --git a/func-app/graphrag/query/llm/text_utils.py b/func-app/graphrag/query/llm/text_utils.py
new file mode 100644
index 0000000000..d60e630488
--- /dev/null
+++ b/func-app/graphrag/query/llm/text_utils.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Text Utilities for LLM."""
+
+from collections.abc import Iterator
+from itertools import islice
+
+import tiktoken
+
+
+def num_tokens(text: str, token_encoder: tiktoken.Encoding | None = None) -> int:
+    """Return the number of tokens in the given text."""
+    if token_encoder is None:
+        token_encoder = tiktoken.get_encoding("cl100k_base")
+    return len(token_encoder.encode(text))  # type: ignore
+
+
+def batched(iterable: Iterator, n: int):
+    """
+    Batch data into tuples of length n. The last batch may be shorter.
+ + Taken from Python's cookbook: https://docs.python.org/3/library/itertools.html#itertools.batched + """ + # batched('ABCDEFG', 3) --> ABC DEF G + if n < 1: + value_error = "n must be at least one" + raise ValueError(value_error) + it = iter(iterable) + while batch := tuple(islice(it, n)): + yield batch + + +def chunk_text( + text: str, max_tokens: int, token_encoder: tiktoken.Encoding | None = None +): + """Chunk text by token length.""" + if token_encoder is None: + token_encoder = tiktoken.get_encoding("cl100k_base") + tokens = token_encoder.encode(text) # type: ignore + chunk_iterator = batched(iter(tokens), max_tokens) + yield from chunk_iterator diff --git a/func-app/graphrag/query/progress.py b/func-app/graphrag/query/progress.py new file mode 100644 index 0000000000..ad5bcee734 --- /dev/null +++ b/func-app/graphrag/query/progress.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Status Reporter for orchestration.""" + +from abc import ABCMeta, abstractmethod +from typing import Any + + +class StatusReporter(metaclass=ABCMeta): + """Provides a way to report status updates from the pipeline.""" + + @abstractmethod + def error(self, message: str, details: dict[str, Any] | None = None): + """Report an error.""" + + @abstractmethod + def warning(self, message: str, details: dict[str, Any] | None = None): + """Report a warning.""" + + @abstractmethod + def log(self, message: str, details: dict[str, Any] | None = None): + """Report a log.""" + + +class ConsoleStatusReporter(StatusReporter): + """A reporter that writes to a console.""" + + def error(self, message: str, details: dict[str, Any] | None = None): + """Report an error.""" + print(message, details) # noqa T201 + + def warning(self, message: str, details: dict[str, Any] | None = None): + """Report a warning.""" + _print_warning(message) + + def log(self, message: str, details: dict[str, Any] | None = None): + """Report a log.""" + print(message, details) # noqa T201 + + +def _print_warning(skk): + print(f"\033[93m {skk}\033[00m") # noqa T201 diff --git a/func-app/graphrag/query/question_gen/__init__.py b/func-app/graphrag/query/question_gen/__init__.py new file mode 100644 index 0000000000..d7329277c2 --- /dev/null +++ b/func-app/graphrag/query/question_gen/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Question Generation Module.""" diff --git a/func-app/graphrag/query/question_gen/base.py b/func-app/graphrag/query/question_gen/base.py new file mode 100644 index 0000000000..959b63d791 --- /dev/null +++ b/func-app/graphrag/query/question_gen/base.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Microsoft Corporation. 
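A quick sketch of the text_utils helpers above, assuming only tiktoken: note that chunk_text yields tuples of token ids rather than decoded strings, so callers decode when they need text back.

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
sample_text = "The quick brown fox jumps over the lazy dog. " * 8
for token_chunk in chunk_text(sample_text, max_tokens=16, token_encoder=enc):
    # Each chunk is a tuple of ints no longer than max_tokens.
    print(len(token_chunk), repr(enc.decode(list(token_chunk))[:40]))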
+# Licensed under the MIT License + +"""Base classes for generating questions based on previously asked questions and most recent context data.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +import tiktoken + +from graphrag.query.context_builder.builders import ( + GlobalContextBuilder, + LocalContextBuilder, +) +from graphrag.query.llm.base import BaseLLM + + +@dataclass +class QuestionResult: + """A Structured Question Result.""" + + response: list[str] + context_data: str | dict[str, Any] + completion_time: float + llm_calls: int + prompt_tokens: int + + +class BaseQuestionGen(ABC): + """The Base Question Gen implementation.""" + + def __init__( + self, + llm: BaseLLM, + context_builder: GlobalContextBuilder | LocalContextBuilder, + token_encoder: tiktoken.Encoding | None = None, + llm_params: dict[str, Any] | None = None, + context_builder_params: dict[str, Any] | None = None, + ): + self.llm = llm + self.context_builder = context_builder + self.token_encoder = token_encoder + self.llm_params = llm_params or {} + self.context_builder_params = context_builder_params or {} + + @abstractmethod + def generate( + self, + question_history: list[str], + context_data: str | None, + question_count: int, + **kwargs, + ) -> QuestionResult: + """Generate questions.""" + + @abstractmethod + async def agenerate( + self, + question_history: list[str], + context_data: str | None, + question_count: int, + **kwargs, + ) -> QuestionResult: + """Generate questions asynchronously.""" diff --git a/func-app/graphrag/query/question_gen/local_gen.py b/func-app/graphrag/query/question_gen/local_gen.py new file mode 100644 index 0000000000..ca703a66e3 --- /dev/null +++ b/func-app/graphrag/query/question_gen/local_gen.py @@ -0,0 +1,194 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Local question generation.""" + +import logging +import time +from typing import Any + +import tiktoken + +from graphrag.query.context_builder.builders import LocalContextBuilder +from graphrag.query.context_builder.conversation_history import ( + ConversationHistory, +) +from graphrag.query.llm.base import BaseLLM, BaseLLMCallback +from graphrag.query.llm.text_utils import num_tokens +from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult +from graphrag.query.question_gen.system_prompt import QUESTION_SYSTEM_PROMPT + +log = logging.getLogger(__name__) + + +class LocalQuestionGen(BaseQuestionGen): + """Search orchestration for global search mode.""" + + def __init__( + self, + llm: BaseLLM, + context_builder: LocalContextBuilder, + token_encoder: tiktoken.Encoding | None = None, + system_prompt: str = QUESTION_SYSTEM_PROMPT, + callbacks: list[BaseLLMCallback] | None = None, + llm_params: dict[str, Any] | None = None, + context_builder_params: dict[str, Any] | None = None, + ): + super().__init__( + llm=llm, + context_builder=context_builder, + token_encoder=token_encoder, + llm_params=llm_params, + context_builder_params=context_builder_params, + ) + self.system_prompt = system_prompt + self.callbacks = callbacks + + async def agenerate( + self, + question_history: list[str], + context_data: str | None, + question_count: int, + **kwargs, + ) -> QuestionResult: + """ + Generate a question based on the question history and context data. 
+ + If context data is not provided, it will be generated by the local context builder + """ + start_time = time.time() + + if len(question_history) == 0: + question_text = "" + conversation_history = None + else: + # construct current query and conversation history + question_text = question_history[-1] + history = [ + {"role": "user", "content": query} for query in question_history[:-1] + ] + conversation_history = ConversationHistory.from_list(history) + + if context_data is None: + # generate context data based on the question history + context_data, context_records = self.context_builder.build_context( + query=question_text, + conversation_history=conversation_history, + **kwargs, + **self.context_builder_params, + ) # type: ignore + else: + context_records = {"context_data": context_data} + log.info("GENERATE QUESTION: %s. LAST QUESTION: %s", start_time, question_text) + system_prompt = "" + try: + system_prompt = self.system_prompt.format( + context_data=context_data, question_count=question_count + ) + question_messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question_text}, + ] + + response = await self.llm.agenerate( + messages=question_messages, + streaming=True, + callbacks=self.callbacks, + **self.llm_params, + ) + + return QuestionResult( + response=response.split("\n"), + context_data={ + "question_context": question_text, + **context_records, + }, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(system_prompt, self.token_encoder), + ) + + except Exception: + log.exception("Exception in generating question") + return QuestionResult( + response=[], + context_data=context_records, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(system_prompt, self.token_encoder), + ) + + def generate( + self, + question_history: list[str], + context_data: str | None, + question_count: int, + **kwargs, + ) -> QuestionResult: + """ + Generate a question based on the question history and context data. + + If context data is not provided, it will be generated by the local context builder + """ + start_time = time.time() + if len(question_history) == 0: + question_text = "" + conversation_history = None + else: + # construct current query and conversation history + question_text = question_history[-1] + history = [ + {"role": "user", "content": query} for query in question_history[:-1] + ] + conversation_history = ConversationHistory.from_list(history) + + if context_data is None: + # generate context data based on the question history + context_data, context_records = self.context_builder.build_context( + query=question_text, + conversation_history=conversation_history, + **kwargs, + **self.context_builder_params, + ) # type: ignore + else: + context_records = {"context_data": context_data} + log.info( + "GENERATE QUESTION: %s. 
QUESTION HISTORY: %s", start_time, question_text + ) + system_prompt = "" + try: + system_prompt = self.system_prompt.format( + context_data=context_data, question_count=question_count + ) + question_messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question_text}, + ] + + response = self.llm.generate( + messages=question_messages, + streaming=True, + callbacks=self.callbacks, + **self.llm_params, + ) + + return QuestionResult( + response=response.split("\n"), + context_data={ + "question_context": question_text, + **context_records, + }, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(system_prompt, self.token_encoder), + ) + + except Exception: + log.exception("Exception in generating questions") + return QuestionResult( + response=[], + context_data=context_records, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(system_prompt, self.token_encoder), + ) diff --git a/func-app/graphrag/query/question_gen/system_prompt.py b/func-app/graphrag/query/question_gen/system_prompt.py new file mode 100644 index 0000000000..904ede2435 --- /dev/null +++ b/func-app/graphrag/query/question_gen/system_prompt.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Question Generation system prompts.""" + +QUESTION_SYSTEM_PROMPT = """ +---Role--- + +You are a helpful assistant generating a bulleted list of {question_count} questions about data in the tables provided. + + +---Data tables--- + +{context_data} + + +---Goal--- + +Given a series of example questions provided by the user, generate a bulleted list of {question_count} candidates for the next question. Use - marks as bullet points. + +These candidate questions should represent the most important or urgent information content or themes in the data tables. + +The candidate questions should be answerable using the data tables provided, but should not mention any specific data fields or data tables in the question text. + +If the user's questions reference several named entities, then each candidate question should reference all named entities. + +---Example questions--- +""" diff --git a/func-app/graphrag/query/structured_search/__init__.py b/func-app/graphrag/query/structured_search/__init__.py new file mode 100644 index 0000000000..b41baaf340 --- /dev/null +++ b/func-app/graphrag/query/structured_search/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Structured Search package.""" diff --git a/func-app/graphrag/query/structured_search/base.py b/func-app/graphrag/query/structured_search/base.py new file mode 100644 index 0000000000..6dd02485f8 --- /dev/null +++ b/func-app/graphrag/query/structured_search/base.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Microsoft Corporation. 
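LocalQuestionGen above fills QUESTION_SYSTEM_PROMPT with a question count and whatever context table the builder produced, then appends the latest user question; a minimal sketch with a placeholder context table:

system_prompt = QUESTION_SYSTEM_PROMPT.format(
    question_count=5,
    context_data="id|entity|description\n1|Company Y|A supplier of ...",  # placeholder
)
question_messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": "Who supplies Company Y?"},  # newest question
]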
+# Licensed under the MIT License + +"""Base classes for search algos.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +import pandas as pd +import tiktoken + +from graphrag.query.context_builder.builders import ( + GlobalContextBuilder, + LocalContextBuilder, +) +from graphrag.query.context_builder.conversation_history import ( + ConversationHistory, +) +from graphrag.query.llm.base import BaseLLM + + +@dataclass +class SearchResult: + """A Structured Search Result.""" + + response: str | dict[str, Any] | list[dict[str, Any]] + context_data: str | list[pd.DataFrame] | dict[str, pd.DataFrame] + # actual text strings that are in the context window, built from context_data + context_text: str | list[str] | dict[str, str] + completion_time: float + llm_calls: int + prompt_tokens: int + + +class BaseSearch(ABC): + """The Base Search implementation.""" + + def __init__( + self, + llm: BaseLLM, + context_builder: GlobalContextBuilder | LocalContextBuilder, + token_encoder: tiktoken.Encoding | None = None, + llm_params: dict[str, Any] | None = None, + context_builder_params: dict[str, Any] | None = None, + ): + self.llm = llm + self.context_builder = context_builder + self.token_encoder = token_encoder + self.llm_params = llm_params or {} + self.context_builder_params = context_builder_params or {} + + @abstractmethod + def search( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> SearchResult: + """Search for the given query.""" + + @abstractmethod + async def asearch( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> SearchResult: + """Search for the given query asynchronously.""" diff --git a/func-app/graphrag/query/structured_search/global_search/__init__.py b/func-app/graphrag/query/structured_search/global_search/__init__.py new file mode 100644 index 0000000000..ba73b60900 --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""GlobalSearch module.""" diff --git a/func-app/graphrag/query/structured_search/global_search/callbacks.py b/func-app/graphrag/query/structured_search/global_search/callbacks.py new file mode 100644 index 0000000000..f48bb79b82 --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/callbacks.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024 Microsoft Corporation. 
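To show the shape every search implementation must return, a minimal sketch of the BaseSearch contract above (NullSearch is a hypothetical name; it answers nothing and makes no LLM calls):

import time


class NullSearch(BaseSearch):
    """Return an empty SearchResult without touching the LLM."""

    def search(self, query, conversation_history=None, **kwargs) -> SearchResult:
        start = time.time()
        return SearchResult(
            response="",
            context_data={},
            context_text="",
            completion_time=time.time() - start,
            llm_calls=0,
            prompt_tokens=0,
        )

    async def asearch(self, query, conversation_history=None, **kwargs) -> SearchResult:
        return self.search(query, conversation_history, **kwargs)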
+# Licensed under the MIT License + +"""GlobalSearch LLM Callbacks.""" + +from graphrag.query.llm.base import BaseLLMCallback +from graphrag.query.structured_search.base import SearchResult + + +class GlobalSearchLLMCallback(BaseLLMCallback): + """GlobalSearch LLM Callbacks.""" + + def __init__(self): + super().__init__() + self.map_response_contexts = [] + self.map_response_outputs = [] + + def on_map_response_start(self, map_response_contexts: list[str]): + """Handle the start of map response.""" + self.map_response_contexts = map_response_contexts + + def on_map_response_end(self, map_response_outputs: list[SearchResult]): + """Handle the end of map response.""" + self.map_response_outputs = map_response_outputs diff --git a/func-app/graphrag/query/structured_search/global_search/community_context.py b/func-app/graphrag/query/structured_search/global_search/community_context.py new file mode 100644 index 0000000000..d63320c85b --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/community_context.py @@ -0,0 +1,99 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Contains algorithms to build context data for global search prompt.""" + +from typing import Any + +import pandas as pd +import tiktoken + +from graphrag.model import CommunityReport, Entity +from graphrag.query.context_builder.community_context import ( + build_community_context, +) +from graphrag.query.context_builder.conversation_history import ( + ConversationHistory, +) +from graphrag.query.structured_search.base import GlobalContextBuilder + + +class GlobalCommunityContext(GlobalContextBuilder): + """GlobalSearch community context builder.""" + + def __init__( + self, + community_reports: list[CommunityReport], + entities: list[Entity] | None = None, + token_encoder: tiktoken.Encoding | None = None, + random_state: int = 86, + ): + self.community_reports = community_reports + self.entities = entities + self.token_encoder = token_encoder + self.random_state = random_state + + def build_context( + self, + conversation_history: ConversationHistory | None = None, + use_community_summary: bool = True, + column_delimiter: str = "|", + shuffle_data: bool = True, + include_community_rank: bool = False, + min_community_rank: int = 0, + community_rank_name: str = "rank", + include_community_weight: bool = True, + community_weight_name: str = "occurrence", + normalize_community_weight: bool = True, + max_tokens: int = 8000, + context_name: str = "Reports", + conversation_history_user_turns_only: bool = True, + conversation_history_max_turns: int | None = 5, + **kwargs: Any, + ) -> tuple[str | list[str], dict[str, pd.DataFrame]]: + """Prepare batches of community report data table as context data for global search.""" + conversation_history_context = "" + final_context_data = {} + if conversation_history: + # build conversation history context + ( + conversation_history_context, + conversation_history_context_data, + ) = conversation_history.build_context( + include_user_turns_only=conversation_history_user_turns_only, + max_qa_turns=conversation_history_max_turns, + column_delimiter=column_delimiter, + max_tokens=max_tokens, + recency_bias=False, + ) + if conversation_history_context != "": + final_context_data = conversation_history_context_data + + community_context, community_context_data = build_community_context( + community_reports=self.community_reports, + entities=self.entities, + token_encoder=self.token_encoder, + use_community_summary=use_community_summary, + 
column_delimiter=column_delimiter, + shuffle_data=shuffle_data, + include_community_rank=include_community_rank, + min_community_rank=min_community_rank, + community_rank_name=community_rank_name, + include_community_weight=include_community_weight, + community_weight_name=community_weight_name, + normalize_community_weight=normalize_community_weight, + max_tokens=max_tokens, + single_batch=False, + context_name=context_name, + random_state=self.random_state, + ) + if isinstance(community_context, list): + final_context = [ + f"{conversation_history_context}\n\n{context}" + for context in community_context + ] + else: + final_context = f"{conversation_history_context}\n\n{community_context}" + + final_context_data.update(community_context_data) + return (final_context, final_context_data) diff --git a/func-app/graphrag/query/structured_search/global_search/map_system_prompt.py b/func-app/graphrag/query/structured_search/global_search/map_system_prompt.py new file mode 100644 index 0000000000..db1a649df3 --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/map_system_prompt.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""System prompts for global search.""" + +MAP_SYSTEM_PROMPT = """ +---Role--- + +You are a helpful assistant responding to questions about data in the tables provided. + + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. + +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response should be JSON formatted as follows: +{{ + "points": [ + {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, + {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + ] +}} + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +Points supported by data should list the relevant reports as references as follows: +"This is an example sentence supported by data references [Data: Reports (report ids)]" + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. + +Do not include information where the supporting evidence for it is not provided. + + +---Data tables--- + +{context_data} + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. 
+ +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +Points supported by data should list the relevant reports as references as follows: +"This is an example sentence supported by data references [Data: Reports (report ids)]" + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. + +Do not include information where the supporting evidence for it is not provided. + +The response should be JSON formatted as follows: +{{ + "points": [ + {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, + {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + ] +}} +""" diff --git a/func-app/graphrag/query/structured_search/global_search/reduce_system_prompt.py b/func-app/graphrag/query/structured_search/global_search/reduce_system_prompt.py new file mode 100644 index 0000000000..701717817c --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/reduce_system_prompt.py @@ -0,0 +1,88 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Global Search system prompts.""" + +REDUCE_SYSTEM_PROMPT = """ +---Role--- + +You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts. + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. + +**Do not list more than 5 record ids in a single reference**. 
Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + + +---Analyst Reports--- + +{report_data} + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. +""" + +NO_DATA_ANSWER = ( + "I am sorry but I am unable to answer this question given the provided data." +) + +GENERAL_KNOWLEDGE_INSTRUCTION = """ +The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example: +"This is an example sentence supported by real-world knowledge [LLM: verify]." +""" diff --git a/func-app/graphrag/query/structured_search/global_search/search.py b/func-app/graphrag/query/structured_search/global_search/search.py new file mode 100644 index 0000000000..12dc45fe7a --- /dev/null +++ b/func-app/graphrag/query/structured_search/global_search/search.py @@ -0,0 +1,359 @@ +# Copyright (c) 2024 Microsoft Corporation. 
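The reduce step formats REDUCE_SYSTEM_PROMPT above with the ranked analyst reports and the requested answer shape, and appends GENERAL_KNOWLEDGE_INSTRUCTION only when general knowledge is allowed; a sketch with a placeholder report:

report_data = (
    "----Analyst 1----\n"
    "Importance Score: 90\n"
    "Key point about Company Y [Data: Reports (2, 7)]"  # placeholder report text
)
reduce_prompt = REDUCE_SYSTEM_PROMPT.format(
    report_data=report_data,
    response_type="multiple paragraphs",
)
reduce_prompt += "\n" + GENERAL_KNOWLEDGE_INSTRUCTION  # only if allow_general_knowledge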
+# Licensed under the MIT License + +"""The GlobalSearch Implementation.""" + +import asyncio +import json +import logging +import time +from dataclasses import dataclass +from typing import Any + +import pandas as pd +import tiktoken + +from graphrag.llm.openai.utils import try_parse_json_object +from graphrag.query.context_builder.builders import GlobalContextBuilder +from graphrag.query.context_builder.conversation_history import ( + ConversationHistory, +) +from graphrag.query.llm.base import BaseLLM +from graphrag.query.llm.text_utils import num_tokens +from graphrag.query.structured_search.base import BaseSearch, SearchResult +from graphrag.query.structured_search.global_search.callbacks import ( + GlobalSearchLLMCallback, +) +from graphrag.query.structured_search.global_search.map_system_prompt import ( + MAP_SYSTEM_PROMPT, +) +from graphrag.query.structured_search.global_search.reduce_system_prompt import ( + GENERAL_KNOWLEDGE_INSTRUCTION, + NO_DATA_ANSWER, + REDUCE_SYSTEM_PROMPT, +) + +DEFAULT_MAP_LLM_PARAMS = { + "max_tokens": 1000, + "temperature": 0.0, +} + +DEFAULT_REDUCE_LLM_PARAMS = { + "max_tokens": 2000, + "temperature": 0.0, +} + +log = logging.getLogger(__name__) + + +@dataclass +class GlobalSearchResult(SearchResult): + """A GlobalSearch result.""" + + map_responses: list[SearchResult] + reduce_context_data: str | list[pd.DataFrame] | dict[str, pd.DataFrame] + reduce_context_text: str | list[str] | dict[str, str] + + +class GlobalSearch(BaseSearch): + """Search orchestration for global search mode.""" + + def __init__( + self, + llm: BaseLLM, + context_builder: GlobalContextBuilder, + token_encoder: tiktoken.Encoding | None = None, + map_system_prompt: str = MAP_SYSTEM_PROMPT, + reduce_system_prompt: str = REDUCE_SYSTEM_PROMPT, + response_type: str = "multiple paragraphs", + allow_general_knowledge: bool = False, + general_knowledge_inclusion_prompt: str = GENERAL_KNOWLEDGE_INSTRUCTION, + json_mode: bool = True, + callbacks: list[GlobalSearchLLMCallback] | None = None, + max_data_tokens: int = 8000, + map_llm_params: dict[str, Any] = DEFAULT_MAP_LLM_PARAMS, + reduce_llm_params: dict[str, Any] = DEFAULT_REDUCE_LLM_PARAMS, + context_builder_params: dict[str, Any] | None = None, + concurrent_coroutines: int = 32, + ): + super().__init__( + llm=llm, + context_builder=context_builder, + token_encoder=token_encoder, + context_builder_params=context_builder_params, + ) + self.map_system_prompt = map_system_prompt + self.reduce_system_prompt = reduce_system_prompt + self.response_type = response_type + self.allow_general_knowledge = allow_general_knowledge + self.general_knowledge_inclusion_prompt = general_knowledge_inclusion_prompt + self.callbacks = callbacks + self.max_data_tokens = max_data_tokens + + self.map_llm_params = map_llm_params + self.reduce_llm_params = reduce_llm_params + if json_mode: + self.map_llm_params["response_format"] = {"type": "json_object"} + else: + # remove response_format key if json_mode is False + self.map_llm_params.pop("response_format", None) + + self.semaphore = asyncio.Semaphore(concurrent_coroutines) + + async def asearch( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs: Any, + ) -> GlobalSearchResult: + """ + Perform a global search. 
+
+        Global search mode includes two steps:
+
+        - Step 1: Run parallel LLM calls on communities' short summaries to generate answer for each batch
+        - Step 2: Combine the answers from step 1 to generate the final answer
+        """
+        # Step 1: Generate answers for each batch of community short summaries
+        start_time = time.time()
+        context_chunks, context_records = self.context_builder.build_context(
+            conversation_history=conversation_history, **self.context_builder_params
+        )
+
+        if self.callbacks:
+            for callback in self.callbacks:
+                callback.on_map_response_start(context_chunks)  # type: ignore
+        map_responses = await asyncio.gather(*[
+            self._map_response_single_batch(
+                context_data=data, query=query, **self.map_llm_params
+            )
+            for data in context_chunks
+        ])
+        if self.callbacks:
+            for callback in self.callbacks:
+                callback.on_map_response_end(map_responses)
+        map_llm_calls = sum(response.llm_calls for response in map_responses)
+        map_prompt_tokens = sum(response.prompt_tokens for response in map_responses)
+
+        # Step 2: Combine the intermediate answers from step 1 to generate the final answer
+        reduce_response = await self._reduce_response(
+            map_responses=map_responses,
+            query=query,
+            **self.reduce_llm_params,
+        )
+
+        return GlobalSearchResult(
+            response=reduce_response.response,
+            context_data=context_records,
+            context_text=context_chunks,
+            map_responses=map_responses,
+            reduce_context_data=reduce_response.context_data,
+            reduce_context_text=reduce_response.context_text,
+            completion_time=time.time() - start_time,
+            llm_calls=map_llm_calls + reduce_response.llm_calls,
+            prompt_tokens=map_prompt_tokens + reduce_response.prompt_tokens,
+        )
+
+    def search(
+        self,
+        query: str,
+        conversation_history: ConversationHistory | None = None,
+        **kwargs: Any,
+    ) -> GlobalSearchResult:
+        """Perform a global search synchronously."""
+        return asyncio.run(self.asearch(query, conversation_history))
+
+    async def _map_response_single_batch(
+        self,
+        context_data: str,
+        query: str,
+        **llm_kwargs,
+    ) -> SearchResult:
+        """Generate answer for a single chunk of community reports."""
+        start_time = time.time()
+        search_prompt = ""
+        try:
+            search_prompt = self.map_system_prompt.format(context_data=context_data)
+            search_messages = [
+                {"role": "system", "content": search_prompt},
+                {"role": "user", "content": query},
+            ]
+            async with self.semaphore:
+                search_response = await self.llm.agenerate(
+                    messages=search_messages, streaming=False, **llm_kwargs
+                )
+                log.info("Map response: %s", search_response)
+            try:
+                # parse search response json
+                processed_response = self.parse_search_response(search_response)
+            except ValueError:
+                # Clean up and retry parse
+                try:
+                    # parse search response json
+                    processed_response = self.parse_search_response(search_response)
+                except ValueError:
+                    log.warning(
+                        "Warning: Error parsing search response json - skipping this batch"
+                    )
+                    processed_response = []
+
+            return SearchResult(
+                response=processed_response,
+                context_data=context_data,
+                context_text=context_data,
+                completion_time=time.time() - start_time,
+                llm_calls=1,
+                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+            )
+
+        except Exception:
+            log.exception("Exception in _map_response_single_batch")
+            return SearchResult(
+                response=[{"answer": "", "score": 0}],
+                context_data=context_data,
+                context_text=context_data,
+                completion_time=time.time() - start_time,
+                llm_calls=1,
+                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+            )
+
+    def parse_search_response(self,
search_response: str) -> list[dict[str, Any]]: + """Parse the search response json and return a list of key points. + + Parameters + ---------- + search_response: str + The search response json string + + Returns + ------- + list[dict[str, Any]] + A list of key points, each key point is a dictionary with "answer" and "score" keys + """ + search_response, _j = try_parse_json_object(search_response) + if _j == {}: + return [{"answer": "", "score": 0}] + + parsed_elements = json.loads(search_response).get("points") + if not parsed_elements or not isinstance(parsed_elements, list): + return [{"answer": "", "score": 0}] + + return [ + { + "answer": element["description"], + "score": int(element["score"]), + } + for element in parsed_elements + if "description" in element and "score" in element + ] + + async def _reduce_response( + self, + map_responses: list[SearchResult], + query: str, + **llm_kwargs, + ) -> SearchResult: + """Combine all intermediate responses from single batches into a final answer to the user query.""" + text_data = "" + search_prompt = "" + start_time = time.time() + try: + # collect all key points into a single list to prepare for sorting + key_points = [] + for index, response in enumerate(map_responses): + if not isinstance(response.response, list): + continue + for element in response.response: + if not isinstance(element, dict): + continue + if "answer" not in element or "score" not in element: + continue + key_points.append({ + "analyst": index, + "answer": element["answer"], + "score": element["score"], + }) + + # filter response with score = 0 and rank responses by descending order of score + filtered_key_points = [ + point + for point in key_points + if point["score"] > 0 # type: ignore + ] + + if len(filtered_key_points) == 0 and not self.allow_general_knowledge: + # return no data answer if no key points are found + log.warning( + "Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), returning a canned 'I do not know' answer. You can try enabling `allow_general_knowledge` to encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations." 
+ ) + return SearchResult( + response=NO_DATA_ANSWER, + context_data="", + context_text="", + completion_time=time.time() - start_time, + llm_calls=0, + prompt_tokens=0, + ) + + filtered_key_points = sorted( + filtered_key_points, + key=lambda x: x["score"], # type: ignore + reverse=True, # type: ignore + ) + + data = [] + total_tokens = 0 + for point in filtered_key_points: + formatted_response_data = [] + formatted_response_data.append( + f'----Analyst {point["analyst"] + 1}----' + ) + formatted_response_data.append( + f'Importance Score: {point["score"]}' # type: ignore + ) + formatted_response_data.append(point["answer"]) # type: ignore + formatted_response_text = "\n".join(formatted_response_data) + if ( + total_tokens + + num_tokens(formatted_response_text, self.token_encoder) + > self.max_data_tokens + ): + break + data.append(formatted_response_text) + total_tokens += num_tokens(formatted_response_text, self.token_encoder) + text_data = "\n\n".join(data) + + search_prompt = self.reduce_system_prompt.format( + report_data=text_data, response_type=self.response_type + ) + if self.allow_general_knowledge: + search_prompt += "\n" + self.general_knowledge_inclusion_prompt + search_messages = [ + {"role": "system", "content": search_prompt}, + {"role": "user", "content": query}, + ] + + search_response = await self.llm.agenerate( + search_messages, + streaming=True, + callbacks=self.callbacks, # type: ignore + **llm_kwargs, # type: ignore + ) + return SearchResult( + response=search_response, + context_data=text_data, + context_text=text_data, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + except Exception: + log.exception("Exception in reduce_response") + return SearchResult( + response="", + context_data=text_data, + context_text=text_data, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) diff --git a/func-app/graphrag/query/structured_search/local_search/__init__.py b/func-app/graphrag/query/structured_search/local_search/__init__.py new file mode 100644 index 0000000000..8b8b1e790e --- /dev/null +++ b/func-app/graphrag/query/structured_search/local_search/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""The LocalSearch package.""" diff --git a/func-app/graphrag/query/structured_search/local_search/mixed_context.py b/func-app/graphrag/query/structured_search/local_search/mixed_context.py new file mode 100644 index 0000000000..af4c63d55b --- /dev/null +++ b/func-app/graphrag/query/structured_search/local_search/mixed_context.py @@ -0,0 +1,533 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Algorithms to build context data for local search prompt."""
+
+import logging
+from typing import Any
+
+import pandas as pd
+import tiktoken
+
+from common.graph_db_client import GraphDBClient
+from graphrag.config.models.graphdb_config import GraphDBConfig
+from graphrag.model import (
+    CommunityReport,
+    Covariate,
+    Entity,
+    Relationship,
+    TextUnit,
+)
+from graphrag.query.context_builder.community_context import (
+    build_community_context,
+)
+from graphrag.query.context_builder.conversation_history import (
+    ConversationHistory,
+)
+from graphrag.query.context_builder.entity_extraction import (
+    EntityVectorStoreKey,
+    map_query_to_entities,
+)
+from graphrag.query.context_builder.local_context import (
+    build_covariates_context,
+    build_entity_context,
+    build_relationship_context,
+    get_candidate_context,
+)
+from graphrag.query.context_builder.source_context import (
+    build_text_unit_context,
+    count_relationships,
+)
+from graphrag.query.input.retrieval.community_reports import (
+    get_candidate_communities,
+)
+from graphrag.query.input.retrieval.text_units import get_candidate_text_units
+from graphrag.query.llm.base import BaseTextEmbedding
+from graphrag.query.llm.text_utils import num_tokens
+from graphrag.query.structured_search.base import LocalContextBuilder
+from graphrag.vector_stores import BaseVectorStore
+from graphrag.vector_stores.kusto import KustoVectorStore
+
+log = logging.getLogger(__name__)
+
+
+class LocalSearchMixedContext(LocalContextBuilder):
+    """Build data context for local search prompt combining community reports and entity/relationship/covariate tables."""
+
+    def __init__(
+        self,
+        entities: list[Entity],
+        entity_text_embeddings: BaseVectorStore,
+        text_embedder: BaseTextEmbedding,
+        text_units: list[TextUnit] | None = None,
+        community_reports: list[CommunityReport] | None = None,
+        relationships: list[Relationship] | None = None,
+        covariates: dict[str, list[Covariate]] | None = None,
+        token_encoder: tiktoken.Encoding | None = None,
+        embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
+        is_optimized_search: bool = False,
+        use_kusto_community_reports: bool = False,
+        graphdb_config: GraphDBConfig | None = None,
+        context_id: str | None = None,
+    ):
+        if community_reports is None:
+            community_reports = []
+        if relationships is None:
+            relationships = []
+        if covariates is None:
+            covariates = {}
+        if text_units is None:
+            text_units = []
+        self.entities = {entity.id: entity for entity in entities}
+        self.community_reports = {
+            community.id: community for community in community_reports
+        }
+        self.text_units = {unit.id: unit for unit in text_units}
+        self.relationships = {
+            relationship.id: relationship for relationship in relationships
+        }
+        self.covariates = covariates
+        self.entity_text_embeddings = entity_text_embeddings
+        self.text_embedder = text_embedder
+        self.token_encoder = token_encoder
+        self.embedding_vectorstore_key = embedding_vectorstore_key
+        self.is_optimized_search = is_optimized_search
+        self.use_kusto_community_reports = use_kusto_community_reports
+        self.graphdb_config = graphdb_config
+        self.context_id = context_id
+
+    def filter_by_entity_keys(self, entity_keys: list[int] | list[str]):
+        """Filter entity text embeddings by entity keys."""
+        self.entity_text_embeddings.filter_by_id(entity_keys)
+
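+    # Illustrative construction of this builder (names and values are
+    # assumptions, not prescribed configuration):
+    #   builder = LocalSearchMixedContext(
+    #       entities=entities,
+    #       entity_text_embeddings=description_embedding_store,
+    #       text_embedder=text_embedder,
+    #       token_encoder=tiktoken.get_encoding("cl100k_base"),
+    #       context_id="ctx-1",
+    #   )
+    def build_context(
+        self,
+        query: str,
+        conversation_history: ConversationHistory | None = None,
+        include_entity_names: list[str] | None = None,
+        exclude_entity_names: list[str] | 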
None = None,
+        conversation_history_max_turns: int | None = 5,
+        conversation_history_user_turns_only: bool = True,
+        max_tokens: int = 8000,
+        text_unit_prop: float = 0.5,
+        community_prop: float = 0.25,
+        top_k_mapped_entities: int = 10,
+        top_k_relationships: int = 10,
+        include_community_rank: bool = False,
+        include_entity_rank: bool = False,
+        rank_description: str = "number of relationships",
+        include_relationship_weight: bool = False,
+        relationship_ranking_attribute: str = "rank",
+        return_candidate_context: bool = False,
+        use_community_summary: bool = False,
+        min_community_rank: int = 0,
+        community_context_name: str = "Reports",
+        column_delimiter: str = "|",
+        is_optimized_search: bool = False,
+        **kwargs: dict[str, Any],
+    ) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
+        """
+        Build data context for local search prompt.
+
+        Build a context by combining community reports, entity/relationship/covariate tables, and text units, splitting the available token budget according to community_prop and text_unit_prop.
+        """
+        if include_entity_names is None:
+            include_entity_names = []
+        if exclude_entity_names is None:
+            exclude_entity_names = []
+        if community_prop + text_unit_prop > 1:
+            value_error = (
+                "The sum of community_prop and text_unit_prop should not exceed 1."
+            )
+            raise ValueError(value_error)
+
+        # map user query to entities
+        # if there is conversation history, attach the previous user questions to the current query
+        if conversation_history:
+            pre_user_questions = "\n".join(
+                conversation_history.get_user_turns(conversation_history_max_turns)
+            )
+            query = f"{query}\n{pre_user_questions}"
+
+        selected_entities = map_query_to_entities(
+            query=query,
+            text_embedding_vectorstore=self.entity_text_embeddings,
+            text_embedder=self.text_embedder,
+            all_entities=list(self.entities.values()),
+            embedding_vectorstore_key=self.embedding_vectorstore_key,
+            include_entity_names=include_entity_names,
+            exclude_entity_names=exclude_entity_names,
+            k=top_k_mapped_entities,
+            oversample_scaler=2,
+        )
+
+        print("Selected entity titles: ", [entity.title for entity in selected_entities])
+
+        # build context
+        final_context = list[str]()
+        final_context_data = dict[str, pd.DataFrame]()
+
+        if conversation_history:
+            # build conversation history context
+            (
+                conversation_history_context,
+                conversation_history_context_data,
+            ) = conversation_history.build_context(
+                include_user_turns_only=conversation_history_user_turns_only,
+                max_qa_turns=conversation_history_max_turns,
+                column_delimiter=column_delimiter,
+                max_tokens=max_tokens,
+                recency_bias=False,
+            )
+            if conversation_history_context.strip() != "":
+                final_context.append(conversation_history_context)
+                final_context_data = conversation_history_context_data
+                max_tokens = max_tokens - num_tokens(
+                    conversation_history_context, self.token_encoder
+                )
+
+        if not is_optimized_search:
+            community_tokens = max(int(max_tokens * community_prop), 0)
+            community_context, community_context_data = self._build_community_context(
+                selected_entities=selected_entities,
+                max_tokens=community_tokens,
+                use_community_summary=use_community_summary,
+                column_delimiter=column_delimiter,
+                include_community_rank=include_community_rank,
+                min_community_rank=min_community_rank,
+                return_candidate_context=return_candidate_context,
+                context_name=community_context_name,
+                is_optimized_search=is_optimized_search,
+            )
+            if community_context.strip() != "":
+                final_context.append(community_context)
+                final_context_data = {**final_context_data, **community_context_data}
+
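+        # Budget split sketch (illustrative numbers, assuming the defaults
+        # above and no conversation history): with max_tokens=8000,
+        # community_prop=0.25 and text_unit_prop=0.5, communities get ~2000
+        # tokens, text units ~4000, and the local entity/relationship/covariate
+        # context the remaining ~2000.
+        # 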
build local (i.e. entity-relationship-covariate) context + local_prop = 1 - community_prop - text_unit_prop + local_tokens = max(int(max_tokens * local_prop), 0) + local_context, local_context_data = self._build_local_context( + selected_entities=selected_entities, + max_tokens=local_tokens, + include_entity_rank=include_entity_rank, + rank_description=rank_description, + include_relationship_weight=include_relationship_weight, + top_k_relationships=top_k_relationships, + relationship_ranking_attribute=relationship_ranking_attribute, + return_candidate_context=return_candidate_context, + column_delimiter=column_delimiter, + is_optimized_search=is_optimized_search + ) + if local_context.strip() != "": + final_context.append(str(local_context)) + final_context_data = {**final_context_data, **local_context_data} + if not self.is_optimized_search: + # build text unit context + text_unit_tokens = max(int(max_tokens * text_unit_prop), 0) + text_unit_context, text_unit_context_data = self._build_text_unit_context( + selected_entities=selected_entities, + max_tokens=text_unit_tokens, + return_candidate_context=return_candidate_context, + ) + if text_unit_context.strip() != "": + final_context.append(text_unit_context) + final_context_data = {**final_context_data, **text_unit_context_data} + + return ("\n\n".join(final_context), final_context_data) + + def _build_community_context( + self, + selected_entities: list[Entity], + max_tokens: int = 4000, + use_community_summary: bool = False, + column_delimiter: str = "|", + include_community_rank: bool = False, + min_community_rank: int = 0, + return_candidate_context: bool = False, + context_name: str = "Reports", + is_optimized_search: bool = False, + ) -> tuple[str, dict[str, pd.DataFrame]]: + """Add community data to the context window until it hits the max_tokens limit.""" + if len(selected_entities) == 0 or (len(self.community_reports) == 0 and not self.use_kusto_community_reports): + return ("", {context_name.lower(): pd.DataFrame()}) + + community_matches = {} + for entity in selected_entities: + # increase count of the community that this entity belongs to + if entity.community_ids: + for community_id in entity.community_ids: + community_matches[community_id] = ( + community_matches.get(community_id, 0) + 1 + ) + + selected_communities = [] + if self.use_kusto_community_reports: + selected_communities = self.entity_text_embeddings.get_extracted_reports( + community_ids=list(community_matches.keys()) + ) + else: + selected_communities = [ + self.community_reports[community_id] + for community_id in community_matches + if community_id in self.community_reports + ] + + # sort communities by number of matched entities and rank + for community in selected_communities: + if community.attributes is None: + community.attributes = {} + community.attributes["matches"] = community_matches[community.id] + selected_communities.sort( + key=lambda x: (x.attributes["matches"], x.rank), # type: ignore + reverse=True, # type: ignore + ) + for community in selected_communities: + del community.attributes["matches"] # type: ignore + context_data = {} + context_data["reports"] = selected_communities + context_text = "" + if not is_optimized_search: + context_text, context_data = build_community_context( + community_reports=selected_communities, + token_encoder=self.token_encoder, + use_community_summary=use_community_summary, + column_delimiter=column_delimiter, + shuffle_data=False, + include_community_rank=include_community_rank, + 
min_community_rank=min_community_rank, + max_tokens=max_tokens, + single_batch=True, + context_name=context_name, + ) + if isinstance(context_text, list) and len(context_text) > 0: + context_text = "\n\n".join(context_text) + + if return_candidate_context: + candidate_context_data = get_candidate_communities( + selected_entities=selected_entities, + community_reports=list(self.community_reports.values()), + use_community_summary=use_community_summary, + include_community_rank=include_community_rank, + ) + context_key = context_name.lower() + if context_key not in context_data: + context_data[context_key] = candidate_context_data + context_data[context_key]["in_context"] = False + else: + if ( + "id" in candidate_context_data.columns + and "id" in context_data[context_key].columns + ): + candidate_context_data["in_context"] = candidate_context_data[ + "id" + ].isin( # cspell:disable-line + context_data[context_key]["id"] + ) + context_data[context_key] = candidate_context_data + else: + context_data[context_key]["in_context"] = True + return (str(context_text), context_data) + + def _build_text_unit_context( + self, + selected_entities: list[Entity], + max_tokens: int = 8000, + return_candidate_context: bool = False, + column_delimiter: str = "|", + context_name: str = "Sources", + ) -> tuple[str, dict[str, pd.DataFrame]]: + """Rank matching text units and add them to the context window until it hits the max_tokens limit.""" + if len(selected_entities) == 0 or len(self.text_units) == 0: + return ("", {context_name.lower(): pd.DataFrame()}) + + selected_text_units = list[TextUnit]() + # for each matching text unit, rank first by the order of the entities that match it, then by the number of matching relationships + # that the text unit has with the matching entities + for index, entity in enumerate(selected_entities): + if entity.text_unit_ids: + for text_id in entity.text_unit_ids: + if ( + text_id not in [unit.id for unit in selected_text_units] + and text_id in self.text_units + ): + selected_unit = self.text_units[text_id] + num_relationships = count_relationships( + selected_unit, entity, self.relationships + ) + if selected_unit.attributes is None: + selected_unit.attributes = {} + selected_unit.attributes["entity_order"] = index + selected_unit.attributes["num_relationships"] = ( + num_relationships + ) + selected_text_units.append(selected_unit) + + # sort selected text units by ascending order of entity order and descending order of number of relationships + selected_text_units.sort( + key=lambda x: ( + x.attributes["entity_order"], # type: ignore + -x.attributes["num_relationships"], # type: ignore + ) + ) + + for unit in selected_text_units: + del unit.attributes["entity_order"] # type: ignore + del unit.attributes["num_relationships"] # type: ignore + + context_text, context_data = build_text_unit_context( + text_units=selected_text_units, + token_encoder=self.token_encoder, + max_tokens=max_tokens, + shuffle_data=False, + context_name=context_name, + column_delimiter=column_delimiter, + ) + + if return_candidate_context: + candidate_context_data = get_candidate_text_units( + selected_entities=selected_entities, + text_units=list(self.text_units.values()), + ) + context_key = context_name.lower() + if context_key not in context_data: + context_data[context_key] = candidate_context_data + context_data[context_key]["in_context"] = False + else: + if ( + "id" in candidate_context_data.columns + and "id" in context_data[context_key].columns + ): + candidate_context_data["in_context"] 
= candidate_context_data[ + "id" + ].isin( # cspell:disable-line + context_data[context_key]["id"] + ) + context_data[context_key] = candidate_context_data + else: + context_data[context_key]["in_context"] = True + return (str(context_text), context_data) + + def _build_local_context( + self, + selected_entities: list[Entity], + max_tokens: int = 8000, + include_entity_rank: bool = False, + rank_description: str = "relationship count", + include_relationship_weight: bool = False, + top_k_relationships: int = 10, + relationship_ranking_attribute: str = "rank", + return_candidate_context: bool = False, + column_delimiter: str = "|", + is_optimized_search: bool = False + ) -> tuple[str, dict[str, pd.DataFrame]]: + """Build data context for local search prompt combining entity/relationship/covariate tables.""" + # build entity context + entity_context, entity_context_data = build_entity_context( + selected_entities=selected_entities, + token_encoder=self.token_encoder, + max_tokens=max_tokens, + column_delimiter=column_delimiter, + include_entity_rank=include_entity_rank, + rank_description=rank_description, + context_name="Entities", + is_optimized_search=is_optimized_search, + ) + entity_tokens = num_tokens(entity_context, self.token_encoder) + + # build relationship-covariate context + added_entities = [] + final_context = [] + final_context_data = {} + + # gradually add entities and associated metadata to the context until we reach limit + graphdb_client=GraphDBClient(self.graphdb_config,self.context_id) if (self.graphdb_config and self.graphdb_config.enabled) else None + for entity in selected_entities: + current_context = [] + current_context_data = {} + added_entities.append(entity) + + # build relationship context + ( + relationship_context, + relationship_context_data, + ) = build_relationship_context( + selected_entities=added_entities, + relationships=list(self.relationships.values()), + token_encoder=self.token_encoder, + max_tokens=max_tokens, + column_delimiter=column_delimiter, + top_k_relationships=top_k_relationships, + include_relationship_weight=include_relationship_weight, + relationship_ranking_attribute=relationship_ranking_attribute, + context_name="Relationships", + is_optimized_search=is_optimized_search, + graphdb_client=graphdb_client, + ) + current_context.append(relationship_context) + current_context_data["relationships"] = relationship_context_data + total_tokens = entity_tokens + num_tokens( + relationship_context, self.token_encoder + ) + + + # build covariate context + for covariate in self.covariates: + covariate_context, covariate_context_data = build_covariates_context( + selected_entities=added_entities, + covariates=self.covariates[covariate], + token_encoder=self.token_encoder, + max_tokens=max_tokens, + column_delimiter=column_delimiter, + context_name=covariate, + is_optimized_search=is_optimized_search + ) + total_tokens += num_tokens(covariate_context, self.token_encoder) + current_context.append(covariate_context) + current_context_data[covariate.lower()] = covariate_context_data + + if total_tokens > max_tokens: + log.info("Reached token limit - reverting to previous context state") + break + + final_context = current_context + final_context_data = current_context_data + + # attach entity context to final context + if graphdb_client: + graphdb_client._client.close() + final_context_text = entity_context + "\n\n" + "\n\n".join(final_context) + final_context_data["entities"] = entity_context_data + + if return_candidate_context: + # we return all the 
candidate entities/relationships/covariates (not only those that were fitted into the context window) + # and add a tag to indicate which records were included in the context window + candidate_context_data = get_candidate_context( + selected_entities=selected_entities, + entities=list(self.entities.values()), + relationships=list(self.relationships.values()), + covariates=self.covariates, + include_entity_rank=include_entity_rank, + entity_rank_description=rank_description, + include_relationship_weight=include_relationship_weight, + ) + for key in candidate_context_data: + candidate_df = candidate_context_data[key] + if key not in final_context_data: + final_context_data[key] = candidate_df + final_context_data[key]["in_context"] = False + else: + in_context_df = final_context_data[key] + + if "id" in in_context_df.columns and "id" in candidate_df.columns: + candidate_df["in_context"] = candidate_df[ + "id" + ].isin( # cspell:disable-line + in_context_df["id"] + ) + final_context_data[key] = candidate_df + else: + final_context_data[key]["in_context"] = True + + else: + for key in final_context_data: + final_context_data[key]["in_context"] = True + return (final_context_text, final_context_data) diff --git a/func-app/graphrag/query/structured_search/local_search/search.py b/func-app/graphrag/query/structured_search/local_search/search.py new file mode 100644 index 0000000000..597b511222 --- /dev/null +++ b/func-app/graphrag/query/structured_search/local_search/search.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""LocalSearch implementation.""" + +import logging +import time +from typing import Any + +import tiktoken + +from graphrag.query.context_builder.builders import LocalContextBuilder +from graphrag.query.context_builder.conversation_history import ( + ConversationHistory, +) +from graphrag.query.llm.base import BaseLLM, BaseLLMCallback +from graphrag.query.llm.text_utils import num_tokens +from graphrag.query.structured_search.base import BaseSearch, SearchResult +from graphrag.query.structured_search.local_search.system_prompt import ( + LOCAL_SEARCH_SYSTEM_PROMPT, +) + +DEFAULT_LLM_PARAMS = { + "max_tokens": 1500, + "temperature": 0.0, +} + +log = logging.getLogger(__name__) + + +class LocalSearch(BaseSearch): + """Search orchestration for local search mode.""" + + def __init__( + self, + llm: BaseLLM, + context_builder: LocalContextBuilder, + token_encoder: tiktoken.Encoding | None = None, + system_prompt: str = LOCAL_SEARCH_SYSTEM_PROMPT, + response_type: str = "multiple paragraphs", + callbacks: list[BaseLLMCallback] | None = None, + llm_params: dict[str, Any] = DEFAULT_LLM_PARAMS, + context_builder_params: dict | None = None, + ): + super().__init__( + llm=llm, + context_builder=context_builder, + token_encoder=token_encoder, + llm_params=llm_params, + context_builder_params=context_builder_params or {}, + ) + self.system_prompt = system_prompt + self.callbacks = callbacks + self.response_type = response_type + + async def asearch( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> SearchResult: + """Build local search context that fits a single context window and generate answer for the user query.""" + start_time = time.time() + search_prompt = "" + + context_text, context_records = self.context_builder.build_context( + query=query, + conversation_history=conversation_history, + **kwargs, + **self.context_builder_params, + ) + log.info("GENERATE ANSWER: %s. 
QUERY: %s", start_time, query) + try: + search_prompt = self.system_prompt.format( + context_data=context_text, response_type=self.response_type + ) + search_messages = [ + {"role": "system", "content": search_prompt}, + {"role": "user", "content": query}, + ] + + response = await self.llm.agenerate( + messages=search_messages, + streaming=True, + callbacks=self.callbacks, + **self.llm_params, + ) + + return SearchResult( + response=response, + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + + except Exception: + log.exception("Exception in _asearch") + return SearchResult( + response="", + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + + def search( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> SearchResult: + """Build local search context that fits a single context window and generate answer for the user question.""" + + start_time = time.time() + search_prompt = "" + context_text, context_records = self.context_builder.build_context( + query=query, + conversation_history=conversation_history, + **kwargs, + **self.context_builder_params, + ) + log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query) + try: + search_prompt = self.system_prompt.format( + context_data=context_text, response_type=self.response_type + ) + search_messages = [ + {"role": "system", "content": search_prompt}, + {"role": "user", "content": query}, + ] + + response = self.llm.generate( + messages=search_messages, + streaming=True, + callbacks=self.callbacks, + **self.llm_params, + ) + + return SearchResult( + response=response, + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + + except Exception: + log.exception("Exception in _map_response_single_batch") + return SearchResult( + response="", + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + + + def optimized_search( + self, + query: str, + conversation_history: ConversationHistory | None = None, + **kwargs, + ) -> SearchResult: + """Build local search context data.""" + start_time = time.time() + search_prompt = "" + context_text, context_records = self.context_builder.build_context( + query=query, + conversation_history=conversation_history, + is_optimized_search = self.optimized_search, + **kwargs, + **self.context_builder_params, + ) + log.info("GENERATE ANSWER: %d. 
QUERY: %s", start_time, query) + try: + return SearchResult( + response="", + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) + + except Exception: + log.exception("Exception in _map_response_single_batch") + return SearchResult( + response="", + context_data=context_records, + context_text=context_text, + completion_time=time.time() - start_time, + llm_calls=1, + prompt_tokens=num_tokens(search_prompt, self.token_encoder), + ) diff --git a/func-app/graphrag/query/structured_search/local_search/system_prompt.py b/func-app/graphrag/query/structured_search/local_search/system_prompt.py new file mode 100644 index 0000000000..70b1d12fc3 --- /dev/null +++ b/func-app/graphrag/query/structured_search/local_search/system_prompt.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Local search system prompts.""" + +LOCAL_SEARCH_SYSTEM_PROMPT = """ +---Role--- + +You are a helpful assistant responding to questions about data in the tables provided. + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. + +If you don't know the answer, just say so. Do not make anything up. + +Points supported by data should list their data references as follows: + +"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." + +Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]." + +where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + + +---Data tables--- + +{context_data} + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. + +If you don't know the answer, just say so. Do not make anything up. + +Points supported by data should list their data references as follows: + +"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." + +Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]." + +where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. 
+ + +---Target response length and format--- + +{response_type} + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. +""" diff --git a/func-app/graphrag/vector_stores/__init__.py b/func-app/graphrag/vector_stores/__init__.py new file mode 100644 index 0000000000..d4c11760aa --- /dev/null +++ b/func-app/graphrag/vector_stores/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing vector-storage implementations.""" + +from .azure_ai_search import AzureAISearch +from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult +from .lancedb import LanceDBVectorStore +from .typing import VectorStoreFactory, VectorStoreType + +__all__ = [ + "AzureAISearch", + "BaseVectorStore", + "LanceDBVectorStore", + "VectorStoreDocument", + "VectorStoreFactory", + "VectorStoreSearchResult", + "VectorStoreType", +] diff --git a/func-app/graphrag/vector_stores/azure_ai_search.py b/func-app/graphrag/vector_stores/azure_ai_search.py new file mode 100644 index 0000000000..9a53c9a5b3 --- /dev/null +++ b/func-app/graphrag/vector_stores/azure_ai_search.py @@ -0,0 +1,225 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A package containing the Azure AI Search vector store implementation.""" + +import json +from typing import Any + +from azure.core.credentials import AzureKeyCredential +from azure.identity import DefaultAzureCredential +from azure.search.documents import SearchClient +from azure.search.documents.indexes import SearchIndexClient +from azure.search.documents.indexes.models import ( + HnswAlgorithmConfiguration, + HnswParameters, + SearchableField, + SearchField, + SearchFieldDataType, + SearchIndex, + SimpleField, + VectorSearch, + VectorSearchAlgorithmMetric, + VectorSearchProfile, +) +from azure.search.documents.models import VectorizedQuery + +from graphrag.model.community_report import CommunityReport +from graphrag.model.entity import Entity +from graphrag.model.types import TextEmbedder +from graphrag.model import TextUnit + +from .base import ( + DEFAULT_VECTOR_SIZE, + BaseVectorStore, + VectorStoreDocument, + VectorStoreSearchResult, +) + + +class AzureAISearch(BaseVectorStore): + """The Azure AI Search vector storage implementation.""" + + index_client: SearchIndexClient + + def connect(self, **kwargs: Any) -> Any: + """Connect to the AzureAI vector store.""" + url = kwargs.get("url", None) + api_key = kwargs.get("api_key", None) + audience = kwargs.get("audience", None) + self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE) + + self.vector_search_profile_name = kwargs.get( + "vector_search_profile_name", "vectorSearchProfile" + ) + + if url: + audience_arg = {"audience": audience} if audience else {} + self.db_connection = SearchClient( + endpoint=url, + index_name=self.collection_name, + credential=AzureKeyCredential(api_key) + if api_key + else DefaultAzureCredential(), + **audience_arg, + ) + self.index_client = SearchIndexClient( + endpoint=url, + credential=AzureKeyCredential(api_key) + if api_key + else DefaultAzureCredential(), + **audience_arg, + ) + else: + not_supported_error = "AAISearchDBClient is not supported on local host." 
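+                # connect() expects a cloud endpoint, e.g. (hypothetical values):
+                #   store.connect(url="https://<service>.search.windows.net", api_key="<key>")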
+                raise ValueError(not_supported_error)
+
+    def load_documents(
+        self, documents: list[VectorStoreDocument], overwrite: bool = True
+    ) -> None:
+        """Load documents into the Azure AI Search index."""
+        if overwrite:
+            if self.collection_name in self.index_client.list_index_names():
+                self.index_client.delete_index(self.collection_name)
+
+            # Configure the vector search profile
+            vector_search = VectorSearch(
+                algorithms=[
+                    HnswAlgorithmConfiguration(
+                        name="HnswAlg",
+                        parameters=HnswParameters(
+                            metric=VectorSearchAlgorithmMetric.COSINE
+                        ),
+                    )
+                ],
+                profiles=[
+                    VectorSearchProfile(
+                        name=self.vector_search_profile_name,
+                        algorithm_configuration_name="HnswAlg",
+                    )
+                ],
+            )
+
+            index = SearchIndex(
+                name=self.collection_name,
+                fields=[
+                    SimpleField(
+                        name="id",
+                        type=SearchFieldDataType.String,
+                        key=True,
+                    ),
+                    SearchField(
+                        name="vector",
+                        type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
+                        searchable=True,
+                        vector_search_dimensions=self.vector_size,
+                        vector_search_profile_name=self.vector_search_profile_name,
+                    ),
+                    SearchableField(name="text", type=SearchFieldDataType.String),
+                    SimpleField(
+                        name="attributes",
+                        type=SearchFieldDataType.String,
+                    ),
+                ],
+                vector_search=vector_search,
+            )
+
+            self.index_client.create_or_update_index(
+                index,
+            )
+
+        batch = [
+            {
+                "id": doc.id,
+                "vector": doc.vector,
+                "text": doc.text,
+                "attributes": json.dumps(doc.attributes),
+            }
+            for doc in documents
+            if doc.vector is not None
+        ]
+
+        if batch:
+            self.db_connection.upload_documents(batch)
+
+    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
+        """Build a query filter to filter documents by a list of ids."""
+        if include_ids is None or len(include_ids) == 0:
+            self.query_filter = None
+            # Returning to keep consistency with other methods, but not needed
+            return self.query_filter
+
+        # More info about odata filtering here: https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
+        # search.in is faster than joined and/or conditions
+        id_filter = ",".join([f"{id!s}" for id in include_ids])
+        self.query_filter = f"search.in(id, '{id_filter}', ',')"
+
+        # Returning to keep consistency with other methods, but not needed
+        # TODO: Refactor on a future PR
+        return self.query_filter
+
+    def similarity_search_by_vector(
+        self, query_embedding: list[float], k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform a vector-based similarity search."""
+        vectorized_query = VectorizedQuery(
+            vector=query_embedding, k_nearest_neighbors=k, fields="vector"
+        )
+
+        response = self.db_connection.search(
+            vector_queries=[vectorized_query],
+        )
+
+        return [
+            VectorStoreSearchResult(
+                document=VectorStoreDocument(
+                    id=doc.get("id", ""),
+                    text=doc.get("text", ""),
+                    vector=doc.get("vector", []),
+                    attributes=(json.loads(doc.get("attributes", "{}"))),
+                ),
+                # Cosine similarity between 0.333 and 1.000
+                # https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
+                score=doc["@search.score"],
+            )
+            for doc in response
+        ]
+
+    def similarity_search_by_text(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform a text-based similarity search."""
+        query_embedding = text_embedder(text)
+        if query_embedding:
+            return self.similarity_search_by_vector(
+                query_embedding=query_embedding, k=k
+            )
+        return []
+
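+    # Example wiring (all endpoint and collection names here are hypothetical):
+    #   store = AzureAISearch(collection_name="entity_description_embeddings",
+    #                         vector_name="vector", reports_name="", text_units_name="")
+    #   store.connect(url="https://<service>.search.windows.net", api_key="<key>")
+    #   hits = store.similarity_search_by_text("renewable energy", embedder, k=5)
+    def load_entities(self, entities: list[Entity], overwrite: bool = True) 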
-> None: + raise NotImplementedError("Loading entities is not supported for Azure AI Search") + + def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + ) -> list[Entity]: + raise NotImplementedError("Extracting entities is not supported for Azure AI Search") + + def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: + raise NotImplementedError("Loading reports is not supported for Azure AI Search") + + def load_text_units(self, units: list[TextUnit], overwrite: bool = True) -> None: + raise NotImplementedError("load_text_units(): Unsupported for this vector store.") + + def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]: + raise NotImplementedError("Extracting reports is not supported for Azure AI Search") + + def setup_entities(self) -> None: + raise NotImplementedError("Setting up entities is not supported for Azure AI Search") + + def setup_reports(self) -> None: + raise NotImplementedError("Setting up reports is not supported for Azure AI Search") + + def setup_text_units(self) -> None: + raise NotImplementedError("setup_text_units(): Unsupported for this vector store.") + + def unload_entities(self) -> None: + raise NotImplementedError("unload_entities(): Unsupported for this vector store.") \ No newline at end of file diff --git a/func-app/graphrag/vector_stores/base.py b/func-app/graphrag/vector_stores/base.py new file mode 100644 index 0000000000..f143bc77f5 --- /dev/null +++ b/func-app/graphrag/vector_stores/base.py @@ -0,0 +1,130 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Base classes for vector stores.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +from graphrag.model.community_report import CommunityReport +from graphrag.model.entity import Entity +from graphrag.model.types import TextEmbedder +from graphrag.model import TextUnit + +DEFAULT_VECTOR_SIZE: int = 1536 + + +@dataclass +class VectorStoreDocument: + """A document that is stored in vector storage.""" + + id: str | int + """unique id for the document""" + + text: str | None + vector: list[float] | None + + attributes: dict[str, Any] = field(default_factory=dict) + """store any additional metadata, e.g. title, date ranges, etc""" + + +@dataclass +class VectorStoreSearchResult: + """A vector storage search result.""" + + document: VectorStoreDocument + """Document that was found.""" + + score: float + """Similarity score between 0 and 1. 
Higher is more similar."""
+
+
+class BaseVectorStore(ABC):
+    """The base class for vector storage data-access classes."""
+
+    def __init__(
+        self,
+        collection_name: str,
+        vector_name: str,
+        reports_name: str,
+        text_units_name: str,
+        db_connection: Any | None = None,
+        document_collection: Any | None = None,
+        query_filter: Any | None = None,
+        **kwargs: Any,
+    ):
+        self.collection_name = collection_name
+        self.vector_name = vector_name
+        self.reports_name = reports_name
+        self.text_units_name = text_units_name
+        self.db_connection = db_connection
+        self.document_collection = document_collection
+        self.query_filter = query_filter
+        self.kwargs = kwargs
+
+    @abstractmethod
+    def connect(self, **kwargs: Any) -> None:
+        """Connect to vector storage."""
+
+    @abstractmethod
+    def load_documents(
+        self, documents: list[VectorStoreDocument], overwrite: bool = True
+    ) -> None:
+        """Load documents into the vector-store."""
+
+    @abstractmethod
+    def similarity_search_by_vector(
+        self, query_embedding: list[float], k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform ANN search by vector."""
+
+    @abstractmethod
+    def similarity_search_by_text(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[VectorStoreSearchResult]:
+        """Perform ANN search by text."""
+
+    @abstractmethod
+    def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
+        """Build a query filter to filter documents by id."""
+
+    @abstractmethod
+    def get_extracted_entities(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[Entity]:
+        """From a query, build a subtable containing only the matching entities."""
+
+    @abstractmethod
+    def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None:
+        """Load entities into the vector-store."""
+
+    @abstractmethod
+    def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None:
+        """Load community reports into the vector-store."""
+
+    @abstractmethod
+    def get_extracted_reports(
+        self, community_ids: list[int], **kwargs: Any
+    ) -> list[CommunityReport]:
+        """Get reports for a given list of community ids."""
+
+    @abstractmethod
+    def setup_entities(self) -> None:
+        """Set up the entities in the vector-store."""
+
+    @abstractmethod
+    def setup_reports(self) -> None:
+        """Set up the reports in the vector-store."""
+
+    @abstractmethod
+    def setup_text_units(self) -> None:
+        """Set up the text units in the vector-store."""
+
+    @abstractmethod
+    def load_text_units(self, units: list[TextUnit], overwrite: bool = True) -> None:
+        """Load text units into the vector-store."""
+
+    @abstractmethod
+    def unload_entities(self) -> None:
+        """Remove context from the databases."""
\ No newline at end of file
diff --git a/func-app/graphrag/vector_stores/kusto.py b/func-app/graphrag/vector_stores/kusto.py
new file mode 100644
index 0000000000..273de5ea96
--- /dev/null
+++ b/func-app/graphrag/vector_stores/kusto.py
@@ -0,0 +1,308 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The Azure Kusto vector storage implementation package."""
+
+import json
+import os
+from typing import Any, List
+
+import pandas as pd
+from azure.kusto.data import KustoClient, KustoConnectionStringBuilder
+from azure.kusto.data.helpers import dataframe_from_result_table
+
+from graphrag.model import TextUnit
+from graphrag.model.community_report import CommunityReport
+from graphrag.model.entity import Entity
+from graphrag.model.types import TextEmbedder
+
+from .base import (
+    BaseVectorStore,
+    VectorStoreDocument,
+    VectorStoreSearchResult,
+)
+
+
+class KustoVectorStore(BaseVectorStore):
+    """The Azure Kusto vector storage implementation."""
+
+    def connect(self, **kwargs: Any) -> Any:
+        """
+        Connect to the vector storage.
+
+        Args:
+            **kwargs: Arbitrary keyword arguments containing connection parameters.
+                - cluster (str): The Kusto cluster URL.
+                - database (str): The Kusto database name.
+                - client_id (str): The client ID for AAD authentication.
+                - client_secret (str): The client secret for AAD authentication.
+                - authority_id (str): The authority ID (tenant ID) for AAD authentication.
+
+        Returns:
+            Any: The Kusto client instance.
+        """
+        cluster = kwargs.get("cluster")
+        database = kwargs.get("database")
+        client_id = kwargs.get("client_id")
+        client_secret = kwargs.get("client_secret")
+        authority_id = kwargs.get("authority_id")
+        env = os.environ.get("ENVIRONMENT")
+        if env == "AZURE":
+            kcsb = KustoConnectionStringBuilder.with_aad_managed_service_identity_authentication(
+                str(cluster), client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"
+            )
+        elif env == "DEVELOPMENT":
+            kcsb = KustoConnectionStringBuilder.with_aad_device_authentication(str(cluster))
+        else:
+            kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication(
+                str(cluster), str(client_id), str(client_secret), str(authority_id))
+        self.client = KustoClient(kcsb)
+        self.database = database
+
+    def load_documents(
+        self, documents: List[VectorStoreDocument], overwrite: bool = True
+    ) -> None:
+        """
+        Load documents into vector storage.
+
+        Args:
+            documents (List[VectorStoreDocument]): List of documents to be loaded.
+            overwrite (bool): Whether to overwrite the existing table. Defaults to True.
+        """
+        data = [
+            {
+                "id": document.id,
+                "text": document.text,
+                "vector": document.vector,
+                "attributes": json.dumps(document.attributes),
+            }
+            for document in documents
+            if document.vector is not None
+        ]
+
+        if len(data) == 0:
+            return
+
+        # Convert data to DataFrame
+        df = pd.DataFrame(data)
+
+        # Create or replace table
+        if overwrite:
+            command = f".drop table {self.collection_name} ifexists"
+            self.client.execute(self.database, command)
+            command = f".create table {self.collection_name} (id: string, text: string, vector: dynamic, attributes: string)"
+            self.client.execute(self.database, command)
+
+        # Ingest data
+        ingestion_command = f".ingest inline into table {self.collection_name} <| {df.to_csv(index=False, header=False)}"
+        self.client.execute(self.database, ingestion_command)
+
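+    # Illustrative control commands emitted above for a collection named
+    # "entities" (values are examples only):
+    #   .create table entities (id: string, text: string, vector: dynamic, attributes: string)
+    #   .ingest inline into table entities <| <csv rows>
+    def filter_by_id(self, include_ids: List[str] | List[int]) -> Any:
+        """
+        Build a query filter to filter documents by id.
+
+        Args:
+            include_ids (List[str] | List[int]): List of document IDs to include in the filter. 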
+ + Returns: + Any: The query filter string. + """ + if len(include_ids) == 0: + self.query_filter = None + else: + if isinstance(include_ids[0], str): + id_filter = ", ".join([f"'{id}'" for id in include_ids]) + self.query_filter = f"id in ({id_filter})" + else: + self.query_filter = ( + f"id in ({', '.join([str(id) for id in include_ids])})" + ) + return self.query_filter + + def similarity_search_by_vector( + self, query_embedding: List[float], k: int = 10, **kwargs: Any + ) -> List[VectorStoreSearchResult]: + """ + Perform a vector-based similarity search. A search to find the k nearest neighbors of the given query vector. + + Args: + query_embedding (List[float]): The query embedding vector. + k (int): The number of top results to return. Defaults to 10. + **kwargs: Additional keyword arguments. + + Returns: + List[VectorStoreSearchResult]: List of search results. + """ + query = f""" + let query_vector = dynamic({query_embedding}); + {self.collection_name} + | extend similarity = series_cosine_similarity(query_vector, {self.vector_name}) + | top {k} by similarity desc + """ + response = self.client.execute(self.database, query) + df = dataframe_from_result_table(response.primary_results[0]) + print("Similarities of the search results:", [row["similarity"] for _, row in df.iterrows()]) + + # Temporary to support the original entity_description_embedding + return [ + VectorStoreSearchResult( + document=VectorStoreDocument( + id=row["id"], + text=row["text"], + vector=row[self.vector_name], + attributes=row["attributes"], + ), + score= 1 + float(row["similarity"]), # 1 + similarity to make it a score between 0 and 2 + ) + for _, row in df.iterrows() + ] + + def similarity_search_by_text( + self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + ) -> list[VectorStoreSearchResult]: + """ + Perform a similarity search using a given input text. + + Args: + text (str): The input text to search for. + text_embedder (TextEmbedder): The text embedder to convert text to vector. + k (int): The number of top results to return. Defaults to 10. + **kwargs: Additional keyword arguments. + + Returns: + List[VectorStoreSearchResult]: List of search results. 
+        """
+        query_embedding = text_embedder(text)
+        if query_embedding:
+            return self.similarity_search_by_vector(query_embedding, k)
+        return []
+
+    def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[Entity]:
+        """Return the k entities whose embeddings are most similar to the query text."""
+        query_embedding = text_embedder(text)
+        query = f"""
+        let query_vector = dynamic({query_embedding});
+        {self.collection_name}
+        | extend similarity = series_cosine_similarity(query_vector, {self.vector_name})
+        | top {k} by similarity desc
+        """
+        response = self.client.execute(self.database, query)
+        df = dataframe_from_result_table(response.primary_results[0])
+
+        return [
+            Entity(
+                id=row["id"],
+                title=row["title"],
+                type=row["type"],
+                description=row["description"],
+                graph_embedding=row["graph_embedding"],
+                text_unit_ids=row["text_unit_ids"],
+                description_embedding=row["description_embedding"],
+                short_id="",
+                community_ids=row["community_ids"],
+                document_ids=row["document_ids"],
+                rank=row["rank"],
+                attributes=row["attributes"],
+            ) for _, row in df.iterrows()
+        ]
+
+    def unload_entities(self) -> None:
+        """Drop the entity, text unit, and report tables for this context."""
+        self.client.execute(self.database, f".drop table {self.collection_name} ifexists")
+        self.client.execute(self.database, f".drop table {self.text_units_name} ifexists")
+        self.client.execute(self.database, f".drop table {self.reports_name} ifexists")
+
+    def setup_entities(self) -> None:
+        """(Re)create the entities table and set Vector16 encoding on its embedding columns."""
+        command = f".drop table {self.collection_name} ifexists"
+        self.client.execute(self.database, command)
+        command = f".create table {self.collection_name} (id: string, short_id: real, title: string, type: string, description: string, description_embedding: dynamic, name_embedding: dynamic, graph_embedding: dynamic, community_ids: dynamic, text_unit_ids: dynamic, document_ids: dynamic, rank: real, attributes: dynamic)"
+        self.client.execute(self.database, command)
+        command = f".alter column {self.collection_name}.graph_embedding policy encoding type = 'Vector16'"
+        self.client.execute(self.database, command)
+        command = f".alter column {self.collection_name}.description_embedding policy encoding type = 'Vector16'"
+        self.client.execute(self.database, command)
+
+    def load_entities(self, entities: list[Entity], overwrite: bool = False) -> None:
+        """Ingest entities into the entities table, optionally recreating it first."""
+        # Convert data to DataFrame
+        df = pd.DataFrame(entities)
+
+        # Create or replace table
+        if overwrite:
+            self.setup_entities()
+
+        # Ingest data
+        ingestion_command = f".ingest inline into table {self.collection_name} <| {df.to_csv(index=False, header=False)}"
+        self.client.execute(self.database, ingestion_command)
+
+    def setup_reports(self) -> None:
+        """(Re)create the reports table and set Vector16 encoding on its embedding columns."""
+        command = f".drop table {self.reports_name} ifexists"
+        self.client.execute(self.database, command)
+        command = f".create table {self.reports_name} (id: string, short_id: string, title: string, community_id: string, summary: string, full_content: string, rank: real, summary_embedding: dynamic, full_content_embedding: dynamic, attributes: dynamic)"
+        self.client.execute(self.database, command)
+        command = f".alter column {self.reports_name}.summary_embedding policy encoding type = 'Vector16'"
+        self.client.execute(self.database, command)
+        command = f".alter column {self.reports_name}.full_content_embedding policy encoding type = 'Vector16'"
+        self.client.execute(self.database, command)
+
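+    # Typical lifecycle sketch (assumed calling order, not enforced by the API):
+    #   store.setup_reports()                     # recreate the table once per context
+    #   store.load_reports(reports)               # append rows
+    #   store.get_extracted_reports([1, 2, 3])    # query back by community id
+    def load_reports(self, reports: list[CommunityReport], overwrite: bool = False) -> None:
+        """Ingest community reports into the reports table, optionally recreating it first."""
+        # Convert data to DataFrame
+        df = pd.DataFrame(reports)
+
+        # Create or replace table
+        if overwrite:
+            self.setup_reports()
+
+        # Ingest data
+        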
ingestion_command = f".ingest inline into table {self.reports_name} <| {df.to_csv(index=False, header=False)}"
+        self.client.execute(self.database, ingestion_command)
+
+    def setup_text_units(self) -> None:
+        """(Re)create the text units table."""
+        command = f".drop table {self.text_units_name} ifexists"
+        self.client.execute(self.database, command)
+        command = f".create table {self.text_units_name} (id: string, text: string, n_tokens: string, document_ids: string, entity_ids: string, relationship_ids: string)"
+        self.client.execute(self.database, command)
+
+    def load_text_units(self, units: list[TextUnit], overwrite: bool = False) -> None:
+        """Ingest text units into the text units table, optionally recreating it first."""
+        df = pd.DataFrame(units)
+        if overwrite:
+            self.setup_text_units()
+
+        ingestion_command = f".ingest inline into table {self.text_units_name} <| {df.to_csv(index=False, header=False)}"
+        self.client.execute(self.database, ingestion_command)
+
+    def get_extracted_reports(
+        self, community_ids: list[int], **kwargs: Any
+    ) -> list[CommunityReport]:
+        """Get reports for a given list of community ids."""
+        ids = ", ".join(str(id) for id in community_ids)
+        query = f"""
+        {self.reports_name}
+        | where community_id in ({ids})
+        """
+        response = self.client.execute(self.database, query)
+        df = dataframe_from_result_table(response.primary_results[0])
+
+        return [
+            CommunityReport(
+                id=row["id"],
+                short_id=row["short_id"],
+                title=row["title"],
+                community_id=row["community_id"],
+                summary=row["summary"],
+                full_content=row["full_content"],
+                rank=row["rank"],
+                summary_embedding=row["summary_embedding"],
+                full_content_embedding=row["full_content_embedding"],
+                attributes=row["attributes"],
+            ) for _, row in df.iterrows()
+        ]
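+
+# End-to-end sketch (cluster, database, and table names are placeholders):
+#   store = KustoVectorStore(collection_name="entities", vector_name="description_embedding",
+#                            reports_name="reports", text_units_name="text_units")
+#   store.connect(cluster="https://<cluster>.kusto.windows.net", database="graphrag")
+#   hits = store.similarity_search_by_vector([0.1] * 1536, k=10)
diff --git a/func-app/graphrag/vector_stores/lancedb.py b/func-app/graphrag/vector_stores/lancedb.py
new file mode 100644
index 0000000000..fb6447d407
--- /dev/null
+++ b/func-app/graphrag/vector_stores/lancedb.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2024 Microsoft Corporation.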
+# Licensed under the MIT License + +"""The LanceDB vector storage implementation package.""" + +import lancedb as lancedb # noqa: I001 (Ruff was breaking on this file imports, even tho they were sorted and passed local tests) +from graphrag.model.community_report import CommunityReport +from graphrag.model.entity import Entity +from graphrag.model.types import TextEmbedder +from graphrag.model import TextUnit + +import json +from typing import Any + +import pyarrow as pa + +from .base import ( + BaseVectorStore, + VectorStoreDocument, + VectorStoreSearchResult, +) + + +class LanceDBVectorStore(BaseVectorStore): + """The LanceDB vector storage implementation.""" + + def connect(self, **kwargs: Any) -> Any: + """Connect to the vector storage.""" + db_uri = kwargs.get("db_uri", "./lancedb") + self.db_connection = lancedb.connect(db_uri) # type: ignore + + def load_documents( + self, documents: list[VectorStoreDocument], overwrite: bool = True + ) -> None: + """Load documents into vector storage.""" + data = [ + { + "id": document.id, + "text": document.text, + "vector": document.vector, + "attributes": json.dumps(document.attributes), + } + for document in documents + if document.vector is not None + ] + + if len(data) == 0: + data = None + + schema = pa.schema([ + pa.field("id", pa.string()), + pa.field("text", pa.string()), + pa.field("vector", pa.list_(pa.float64())), + pa.field("attributes", pa.string()), + ]) + if overwrite: + if data: + self.document_collection = self.db_connection.create_table( + self.collection_name, data=data, mode="overwrite" + ) + else: + self.document_collection = self.db_connection.create_table( + self.collection_name, schema=schema, mode="overwrite" + ) + else: + # add data to existing table + self.document_collection = self.db_connection.open_table( + self.collection_name + ) + if data: + self.document_collection.add(data) + + def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: + """Build a query filter to filter documents by id.""" + if len(include_ids) == 0: + self.query_filter = None + else: + if isinstance(include_ids[0], str): + id_filter = ", ".join([f"'{id}'" for id in include_ids]) + self.query_filter = f"id in ({id_filter})" + else: + self.query_filter = ( + f"id in ({', '.join([str(id) for id in include_ids])})" + ) + return self.query_filter + + def similarity_search_by_vector( + self, query_embedding: list[float], k: int = 10, **kwargs: Any + ) -> list[VectorStoreSearchResult]: + """Perform a vector-based similarity search.""" + if self.query_filter: + docs = ( + self.document_collection.search(query=query_embedding) + .where(self.query_filter, prefilter=True) + .limit(k) + .to_list() + ) + else: + docs = ( + self.document_collection.search(query=query_embedding) + .limit(k) + .to_list() + ) + return [ + VectorStoreSearchResult( + document=VectorStoreDocument( + id=doc["id"], + text=doc["text"], + vector=doc["vector"], + attributes=json.loads(doc["attributes"]), + ), + score=1 - abs(float(doc["_distance"])), + ) + for doc in docs + ] + + def similarity_search_by_text( + self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + ) -> list[VectorStoreSearchResult]: + """Perform a similarity search using a given input text.""" + query_embedding = text_embedder(text) + if query_embedding: + return self.similarity_search_by_vector(query_embedding, k) + return [] + + def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: + raise NotImplementedError("Loading entities is not supported for LanceDB") + 
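+    # Illustrative local usage of this store (paths and names are examples):
+    #   store = LanceDBVectorStore(collection_name="entity_description_embeddings",
+    #                              vector_name="vector", reports_name="", text_units_name="")
+    #   store.connect(db_uri="./lancedb")
+    #   store.load_documents(docs)
+    #   hits = store.similarity_search_by_vector([0.1] * 1536, k=5)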
+    def get_extracted_entities(
+        self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
+    ) -> list[Entity]:
+        raise NotImplementedError("Extracting entities is not supported for LanceDB")
+
+    def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None:
+        raise NotImplementedError("Loading reports is not supported for LanceDB")
+
+    def load_text_units(self, units: list[TextUnit], overwrite: bool = True) -> None:
+        raise NotImplementedError("Loading text units is not supported for LanceDB")
+
+    def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]:
+        raise NotImplementedError("Extracting community reports is not supported for LanceDB")
+
+    def setup_entities(self) -> None:
+        raise NotImplementedError("Setting up entities is not supported for LanceDB")
+
+    def setup_reports(self) -> None:
+        raise NotImplementedError("Setting up community reports is not supported for LanceDB")
+
+    def setup_text_units(self) -> None:
+        raise NotImplementedError("Setting up text units is not supported for LanceDB")
+
+    def unload_entities(self) -> None:
+        raise NotImplementedError("Unloading entities is not supported for LanceDB")
diff --git a/func-app/graphrag/vector_stores/typing.py b/func-app/graphrag/vector_stores/typing.py
new file mode 100644
index 0000000000..459d5b5f56
--- /dev/null
+++ b/func-app/graphrag/vector_stores/typing.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A package containing the supported vector store types."""
+
+from enum import Enum
+from typing import ClassVar
+
+from .azure_ai_search import AzureAISearch
+from .kusto import KustoVectorStore
+from .lancedb import LanceDBVectorStore
+
+
+class VectorStoreType(str, Enum):
+    """The supported vector store types."""
+
+    LanceDB = "lancedb"
+    AzureAISearch = "azure_ai_search"
+    Kusto = "kusto"
+
+
+class VectorStoreFactory:
+    """A factory class for creating vector stores."""
+
+    vector_store_types: ClassVar[dict[str, type]] = {}
+
+    @classmethod
+    def register(cls, vector_store_type: str, vector_store: type):
+        """Register a vector store type."""
+        cls.vector_store_types[vector_store_type] = vector_store
+
+    @classmethod
+    def get_vector_store(
+        cls, vector_store_type: VectorStoreType | str, kwargs: dict
+    ) -> LanceDBVectorStore | AzureAISearch | KustoVectorStore:
+        """Create a vector store instance of the given type."""
+        match vector_store_type:
+            case VectorStoreType.LanceDB:
+                return LanceDBVectorStore(**kwargs)
+            case VectorStoreType.AzureAISearch:
+                return AzureAISearch(**kwargs)
+            case VectorStoreType.Kusto:
+                return KustoVectorStore(**kwargs)
+            case _:
+                if vector_store_type in cls.vector_store_types:
+                    return cls.vector_store_types[vector_store_type](**kwargs)
+                msg = f"Unknown vector store type: {vector_store_type}"
+                raise ValueError(msg)
diff --git a/func-app/host.json b/func-app/host.json
new file mode 100644
index 0000000000..9df913614d
--- /dev/null
+++ b/func-app/host.json
@@ -0,0 +1,15 @@
+{
+    "version": "2.0",
+    "logging": {
+        "applicationInsights": {
+            "samplingSettings": {
+                "isEnabled": true,
+                "excludedTypes": "Request"
+            }
+        }
+    },
+    "extensionBundle": {
+        "id": "Microsoft.Azure.Functions.ExtensionBundle",
+        "version": "[4.*, 5.0.0)"
+    }
+}
\ No newline at end of file
diff --git a/func-app/prompts/claim_extraction.txt b/func-app/prompts/claim_extraction.txt
new file mode 100644
index 0000000000..0b795c3465
--- /dev/null
+++ b/func-app/prompts/claim_extraction.txt
@@ -0,0 +1,52 @@
+
+-Target activity-
+You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.
+
+-Goal-
+Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.
+
+-Steps-
+1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
+2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
+For each claim, extract the following information:
+- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
+- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
+- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
+- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
+- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
+- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
+- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.
+
+Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
+
+3. Return output in English as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
+
+4. When finished, output {completion_delimiter}
+
+-Examples-
+Example 1:
+Entity specification: organization
+Claim description: red flags associated with an entity
+Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
+Output:
+
+(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
+{completion_delimiter} + +Example 2: +Entity specification: Company A, Person C +Claim description: red flags associated with an entity +Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015. +Output: + +(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.) +{record_delimiter} +(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015) +{completion_delimiter} + +-Real Data- +Use the following input for your answer. +Entity specification: {entity_specs} +Claim description: {claim_description} +Text: {input_text} +Output: \ No newline at end of file diff --git a/func-app/prompts/community_report.txt b/func-app/prompts/community_report.txt new file mode 100644 index 0000000000..d71440ab2f --- /dev/null +++ b/func-app/prompts/community_report.txt @@ -0,0 +1,146 @@ + +You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network. + +# Goal +Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims. + +# Report Structure + +The report should include the following sections: + +- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. +- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. +- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community. +- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating. +- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. 
+
+Return output as a well-formed JSON-formatted string with the following format:
+    {{
+        "title": <report_title>,
+        "summary": <executive_summary>,
+        "rating": <impact_severity_rating>,
+        "rating_explanation": <rating_explanation>,
+        "findings": [
+            {{
+                "summary":<insight_1_summary>,
+                "explanation": <insight_1_explanation>
+            }},
+            {{
+                "summary":<insight_2_summary>,
+                "explanation": <insight_2_explanation>
+            }}
+        ]
+    }}
+
+# Grounding Rules
+
+Points supported by data should list their data references as follows:
+
+"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
+
+Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
+
+where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+Do not include information where the supporting evidence for it is not provided.
+
+
+# Example Input
+-----------
+Text:
+
+Entities
+
+id,entity,description
+5,VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March
+6,HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza
+
+Relationships
+
+id,source,target,description
+37,VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March
+38,VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza
+39,VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza
+40,VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza
+41,VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march
+43,HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March
+
+Output:
+{{
+    "title": "Verdant Oasis Plaza and Unity March",
+    "summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.",
+    "rating": 5.0,
+    "rating_explanation": "The impact severity rating is moderate due to the potential for unrest or conflict during the Unity March.",
+    "findings": [
+        {{
+            "summary": "Verdant Oasis Plaza as the central location",
+            "explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes. [Data: Entities (5), Relationships (37, 38, 39, 40, 41,+more)]"
+        }},
+        {{
+            "summary": "Harmony Assembly's role in the community",
+            "explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community. [Data: Entities(6), Relationships (38, 43)]"
+        }},
+        {{
+            "summary": "Unity March as a significant event",
+            "explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community. [Data: Relationships (39)]"
+        }},
+        {{
+            "summary": "Role of Tribune Spotlight",
+            "explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved. [Data: Relationships (40)]"
+        }}
+    ]
+}}
+
+
+# Real Data
+
+Use the following text for your answer. Do not make anything up in your answer.
+
+Text:
+{input_text}
+
+The report should include the following sections:
+
+- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
+- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
+- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
+- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
+- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
+
+Return output as a well-formed JSON-formatted string with the following format:
+    {{
+        "title": <report_title>,
+        "summary": <executive_summary>,
+        "rating": <impact_severity_rating>,
+        "rating_explanation": <rating_explanation>,
+        "findings": [
+            {{
+                "summary":<insight_1_summary>,
+                "explanation": <insight_1_explanation>
+            }},
+            {{
+                "summary":<insight_2_summary>,
+                "explanation": <insight_2_explanation>
+            }}
+        ]
+    }}
+
+# Grounding Rules
+
+Points supported by data should list their data references as follows:
+
+"This is an example sentence supported by multiple data references [Data: <dataset name> (record ids); <dataset name> (record ids)]."
+
+Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]."
+
+where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+Do not include information where the supporting evidence for it is not provided.
+
+Output:
\ No newline at end of file
diff --git a/func-app/prompts/entity_extraction.txt b/func-app/prompts/entity_extraction.txt
new file mode 100644
index 0000000000..d47747f7cf
--- /dev/null
+++ b/func-app/prompts/entity_extraction.txt
@@ -0,0 +1,99 @@
+
+-Goal-
+Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
+
+-Steps-
+1. Identify all entities. For each identified entity, extract the following information:
+- entity_name: Name of the entity, capitalized
+- entity_type: One of the following types: [{entity_types}]
+- entity_description: Comprehensive description of the entity's attributes and activities
+Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
+
+2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: name of the source entity, as identified in step 1
+- target_entity: name of the target entity, as identified in step 1
+- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
+- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
+ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
+
+3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
+
+4. When finished, output {completion_delimiter}
+
+######################
+-Examples-
+######################
+Example 1:
+
+Entity_types: [person, technology, mission, organization, location]
+Text:
+while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
+
+Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
+
+The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
+
+It was a small transformation, barely perceptible, but one that Alex noted with an inward nod.
They had all been brought here by different paths +################ +Output: +("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter} +("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter} +("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter} +("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter} +("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}7){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}6){record_delimiter} +("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}8){record_delimiter} +("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}5){record_delimiter} +("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}9){completion_delimiter} +############################# +Example 2: + +Entity_types: [person, technology, mission, organization, location] +Text: +They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve. + +Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril. + +Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. 
A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly +############# +Output: +("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter} +("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter} +("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter} +("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}7){record_delimiter} +("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}9){completion_delimiter} +############################# +Example 3: + +Entity_types: [person, role, technology, organization, event, location, concept] +Text: +their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data. + +"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning." + +Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back." + +Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history. 
+
+The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
+#############
+Output:
+("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
+("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
+("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
+("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
+("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
+("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
+("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}9){record_delimiter}
+("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}10){record_delimiter}
+("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}8){record_delimiter}
+("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}7){completion_delimiter}
+#############################
+-Real Data-
+######################
+Entity_types: {entity_types}
+Text: {input_text}
+######################
+Output:
\ No newline at end of file
diff --git a/func-app/prompts/summarize_descriptions.txt b/func-app/prompts/summarize_descriptions.txt
new file mode 100644
index 0000000000..54feaf32d4
--- /dev/null
+++ b/func-app/prompts/summarize_descriptions.txt
@@ -0,0 +1,13 @@
+
+You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
+You are given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
+Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
+If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
+Make sure it is written in third person, and include the entity names so we have the full context.
+ +####### +-Data- +Entities: {entity_name} +Description List: {description_list} +####### +Output: diff --git a/func-app/requirements.txt b/func-app/requirements.txt new file mode 100644 index 0000000000..4b51af40be --- /dev/null +++ b/func-app/requirements.txt @@ -0,0 +1,142 @@ +aenum==3.1.15 +aiofiles==24.1.0 +aiohappyeyeballs==2.4.0 +aiohttp==3.10.5 +aiolimiter==1.1.0 +aiosignal==1.3.1 +annotated-types==0.7.0 +anyio==4.4.0 +anytree==2.12.1 +asttokens==2.4.1 +async-timeout==4.0.3 +attrs==24.2.0 +autograd==1.7.0 +azure-common==1.1.28 +azure-core==1.31.0 +azure-cosmos==4.7.0 +azure-identity==1.17.1 +azure-kusto-data==4.5.1 +azure-search-documents==11.5.1 +azure-storage-blob==12.23.0 +beartype==0.18.5 +cachetools==5.5.0 +certifi==2024.8.30 +cffi==1.17.1 +charset-normalizer==3.3.2 +click==8.1.7 +cloudpickle==3.0.0 +colorama==0.4.6 +contourpy==1.3.0 +cramjam==2.8.3 +cryptography==43.0.1 +cycler==0.12.1 +dask-expr==1.1.14 +dask==2024.9.0 +dask[dataframe]==2024.9.0 +datashaper==0.0.49 +decorator==5.1.1 +deprecation==2.1.0 +devtools==0.12.2 +diskcache==5.6.3 +distro==1.9.0 +environs==11.0.0 +exceptiongroup==1.2.2 +executing==2.1.0 +fastparquet==2024.5.0 +fonttools==4.53.1 +frozenlist==1.4.1 +fsspec==2024.9.0 +gensim==4.3.3 +graspologic-native==1.2.1 +graspologic==3.4.1 +gremlinpython==3.7.2 +h11==0.14.0 +httpcore==1.0.5 +httpx==0.27.2 +hyppo==0.4.0 +idna==3.10 +ijson==3.3.0 +importlib-metadata==8.5.0 +isodate==0.6.1 +jiter==0.5.0 +joblib==1.4.2 +json-repair==0.25.3 +jsonschema-specifications==2023.12.1 +jsonschema==4.23.0 +kiwisolver==1.4.7 +lancedb==0.11.0 +linkify-it-py==2.0.3 +llvmlite==0.43.0 +locket==1.0.0 +markdown-it-py==3.0.0 +markdown-it-py[linkify,plugins]==3.0.0 +marshmallow==3.22.0 +matplotlib==3.9.2 +mdit-py-plugins==0.4.2 +mdurl==0.1.2 +msal-extensions==1.2.0 +msal==1.31.0 +multidict==6.1.0 +nest-asyncio==1.6.0 +networkx==3.3 +nltk==3.8.1 +numba==0.60.0 +numpy==1.26.4 +openai==1.46.0 +overrides==7.7.0 +packaging==24.1 +pandas==2.2.2 +partd==1.4.2 +patsy==0.5.6 +pillow==10.4.0 +portalocker==2.10.1 +pot==0.9.4 +psutil==6.0.0 +py==1.11.0 +pyaml-env==1.2.1 +pyarrow==15.0.2 +pycparser==2.22 +pydantic-core==2.23.4 +pydantic==2.9.2 +pygments==2.18.0 +pyjwt[crypto]==2.9.0 +pylance==0.15.0 +pynndescent==0.5.13 +pyparsing==3.1.4 +python-dateutil==2.9.0.post0 +python-dotenv==1.0.1 +pytz==2024.2 +pywin32==306 ## This needs to be commented when deploying to the func app. +pyyaml==6.0.2 +ratelimiter==1.2.0.post0 +referencing==0.35.1 +regex==2024.9.11 +requests==2.32.3 +retry==0.9.2 +rich==13.8.1 +rpds-py==0.20.0 +scikit-learn==1.5.2 +scipy==1.12.0 +seaborn==0.13.2 +six==1.16.0 +smart-open==7.0.4 +sniffio==1.3.1 +statsmodels==0.14.3 +swifter==1.4.0 +tenacity==8.5.0 +textual==0.74.0 +threadpoolctl==3.5.0 +tiktoken==0.7.0 +toolz==0.12.1 +tqdm==4.66.5 +typing-extensions==4.12.2 +tzdata==2024.1 +uc-micro-py==1.0.3 +umap-learn==0.5.6 +urllib3==2.2.3 +wrapt==1.16.0 +yarl==1.11.1 +zipp==3.20.2 +azure.functions +future +pandas \ No newline at end of file diff --git a/func-app/settings/settings.yaml b/func-app/settings/settings.yaml new file mode 100644 index 0000000000..1e6a36da3e --- /dev/null +++ b/func-app/settings/settings.yaml @@ -0,0 +1,152 @@ + +encoding_model: cl100k_base +skip_workflows: [] +llm: + api_key: ${GRAPHRAG_API_KEY} + type: azure_openai_chat # openai_chat # or azure_openai_chat + model: gpt-4o + model_supports_json: true # recommended if this is available for your model. 
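+  # NOTE: api_base and deployment_name below are environment-specific Azure
+  # OpenAI values; replace them with your own resource endpoint and deployment.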
+ # max_tokens: 4000 + # request_timeout: 180.0 + api_base: https://spe-pds-ais-ai.openai.azure.com + api_version: 2024-04-01-preview + # organization: + deployment_name: spepdsaigpt + # tokens_per_minute: 150_000 # set a leaky bucket throttle + # requests_per_minute: 10_000 # set a leaky bucket throttle + # max_retries: 10 + # max_retry_wait: 10.0 + # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times + # concurrent_requests: 25 # the number of parallel inflight requests that may be made + # temperature: 0 # temperature for sampling + # top_p: 1 # top-p sampling + # n: 1 # Number of completions to generate + +parallelization: + stagger: 0.3 + # num_threads: 50 # the number of threads to use for parallel processing + +async_mode: threaded # or asyncio + +embeddings: + ## parallelization: override the global parallelization settings for embeddings + async_mode: threaded # or asyncio + llm: + api_key: ${GRAPHRAG_API_KEY} + type: azure_openai_embedding #openai_embedding # or azure_openai_embedding + model: text-embedding-ada-002 + api_base: https://spe-pds-ais-ai.openai.azure.com + api_version: 2024-04-01-preview + # organization: + deployment_name: spepdsaista + # tokens_per_minute: 150_000 # set a leaky bucket throttle + # requests_per_minute: 10_000 # set a leaky bucket throttle + # max_retries: 10 + # max_retry_wait: 10.0 + # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times + # concurrent_requests: 25 # the number of parallel inflight requests that may be made + # batch_size: 16 # the number of documents to send in a single request + # batch_max_tokens: 8191 # the maximum number of tokens to send in a single request + # target: required # or optional + +chunks: + size: 500 + overlap: 100 + group_by_columns: [id] # by default, we don't allow chunks to cross documents + +input: + type: file # or blob + file_type: text # or csv + base_dir: "input" + file_encoding: utf-8 + file_pattern: ".*\\.txt$" + +cache: + type: file # or blob + base_dir: "cache" + # connection_string: + # container_name: + +storage: + type: file # or blob + base_dir: "output/${timestamp}/artifacts" + # connection_string: + # container_name: + +reporting: + type: file # or console, blob + base_dir: "output/${timestamp}/reports" + # connection_string: + # container_name: + +entity_extraction: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + prompt: "prompts/entity_extraction.txt" + entity_types: [organization,person,geo,event] + max_gleanings: 1 + +summarize_descriptions: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + prompt: "prompts/summarize_descriptions.txt" + max_length: 500 + +claim_extraction: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + # enabled: true + prompt: "prompts/claim_extraction.txt" + description: "Any claims or facts that could be relevant to information discovery." 
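+  ## max_gleanings: the number of follow-up extraction passes used to catch
+  ## claims that were missed on the first pass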
+ max_gleanings: 1 + +community_reports: + ## llm: override the global llm settings for this task + ## parallelization: override the global parallelization settings for this task + ## async_mode: override the global async_mode settings for this task + prompt: "prompts/community_report.txt" + max_length: 2000 + max_input_length: 8000 + +cluster_graph: + max_cluster_size: 10 + +embed_graph: + enabled: true # if true, will generate node2vec embeddings for nodes + # num_walks: 10 + # walk_length: 40 + # window_size: 2 + # iterations: 3 + # random_seed: 597832 + +umap: + enabled: false # if true, will generate UMAP embeddings for nodes + +snapshots: + graphml: false + raw_entities: false + top_level_nodes: false + +local_search: + # text_unit_prop: 0.5 + # community_prop: 0.1 + # conversation_history_max_turns: 5 + # top_k_mapped_entities: 10 + # top_k_relationships: 10 + # llm_temperature: 0 # temperature for sampling + # llm_top_p: 1 # top-p sampling + # llm_n: 1 # Number of completions to generate + # max_tokens: 12000 + +global_search: + # llm_temperature: 0 # temperature for sampling + # llm_top_p: 1 # top-p sampling + # llm_n: 1 # Number of completions to generate + # max_tokens: 12000 + # data_max_tokens: 12000 + # map_max_tokens: 1000 + # reduce_max_tokens: 2000 + # concurrency: 32 From 2f32844cbf3be06c92ccfa7795737388c306c98b Mon Sep 17 00:00:00 2001 From: Prateek Jain Date: Mon, 23 Sep 2024 10:41:46 -0700 Subject: [PATCH 83/87] added the local settings file forcefully --- func-app/local.settings.json | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 func-app/local.settings.json diff --git a/func-app/local.settings.json b/func-app/local.settings.json new file mode 100644 index 0000000000..504fd3d778 --- /dev/null +++ b/func-app/local.settings.json @@ -0,0 +1,13 @@ +{ + "IsEncrypted": false, + "Values": { + "FUNCTIONS_WORKER_RUNTIME": "python", + "AzureWebJobsStorage": "${AZURE_STORAGE_CONNECTIONSTRING}", + "StorageKey" : "${AZURE_STORAGE_KEY}", + "APPINSIGHTS_INSTRUMENTATIONKEY" : "${AZURE_APPLICATION_INSIGHTS_INSTRUMENTATION_KEY}" + }, + "ConnectionStrings": { + "DbConnectionString": "" + } + } + \ No newline at end of file From cdad2086b52f1bdc6b04b4efed47ffb8a179c5d3 Mon Sep 17 00:00:00 2001 From: logomachic Date: Mon, 23 Sep 2024 11:12:01 -0700 Subject: [PATCH 84/87] Update local.settings.json Syntax update --- func-app/local.settings.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/func-app/local.settings.json b/func-app/local.settings.json index 504fd3d778..f4addbf429 100644 --- a/func-app/local.settings.json +++ b/func-app/local.settings.json @@ -9,5 +9,5 @@ "ConnectionStrings": { "DbConnectionString": "" } - } - \ No newline at end of file +} + From 08d80b8810764a85df0dec5dca60783f8b2e02d1 Mon Sep 17 00:00:00 2001 From: amritpalms Date: Mon, 23 Sep 2024 18:41:12 -0700 Subject: [PATCH 85/87] config for debugger --- func-app/.vscode/launch.json | 20 +++++++-------- func-app/.vscode/tasks.json | 48 ++++++++++++++++++------------------ 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/func-app/.vscode/launch.json b/func-app/.vscode/launch.json index a90b7259e1..5cb5555156 100644 --- a/func-app/.vscode/launch.json +++ b/func-app/.vscode/launch.json @@ -1,13 +1,13 @@ { - "version": "0.2.0", - "configurations": [ - { - "name": "Attach to Python Functions", - "type": "python", - "request": "attach", - "port": 7071, - "preLaunchTask": "func: host start", - "justMyCode": true - } + "version": "0.2.0", + "configurations": [ + { 
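+            // Attaches to the debugpy listener exposed by the Functions host;
+            // the port below must match that listener (not the default HTTP port 7071).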
+ "name": "Attach to Python Functions", + "type": "python", + "request": "attach", + "port": 9091, + "preLaunchTask": "func: host start", + "justMyCode": true + } ] } \ No newline at end of file diff --git a/func-app/.vscode/tasks.json b/func-app/.vscode/tasks.json index 808884468c..4dfedc78dd 100644 --- a/func-app/.vscode/tasks.json +++ b/func-app/.vscode/tasks.json @@ -1,26 +1,26 @@ { - "version": "2.0.0", - "tasks": [ - { - "type": "func", - "command": "host start", - "problemMatcher": "$func-watch", - "isBackground": true, - "dependsOn": "pipInstall" - }, - { - "label": "pipInstall", - "type": "shell", - "osx": { - "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt" - }, - "windows": { - "command": "${config:azureFunctions.pythonVenv}\\Scripts\\python -m pip install -r requirements.txt" - }, - "linux": { - "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt" - }, - "problemMatcher": [] - } - ] + "version": "2.0.0", + "tasks": [ + { + "type": "func", + "command": "host start", + "problemMatcher": "$func-python-watch", + "isBackground": true, + "dependsOn": "pipInstall" + }, + { + "label": "pipInstall", + "type": "shell", + "osx": { + "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt" + }, + "windows": { + "command": "${config:azureFunctions.pythonVenv}\\Scripts\\python -m pip install -r requirements.txt" + }, + "linux": { + "command": "${config:azureFunctions.pythonVenv}/bin/python -m pip install -r requirements.txt" + }, + "problemMatcher": [] + } + ] } \ No newline at end of file From 984fa0db2ccb3db76f1171fe0e260e744ea95b87 Mon Sep 17 00:00:00 2001 From: amritpalms Date: Mon, 23 Sep 2024 20:21:47 -0700 Subject: [PATCH 86/87] DefaultAuthCredes for llm --- func-app/graphrag/llm/openai/create_openai_client.py | 4 ++-- func-app/graphrag/query/factories.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/func-app/graphrag/llm/openai/create_openai_client.py b/func-app/graphrag/llm/openai/create_openai_client.py index cd149323c6..1bc0c2e042 100644 --- a/func-app/graphrag/llm/openai/create_openai_client.py +++ b/func-app/graphrag/llm/openai/create_openai_client.py @@ -6,7 +6,7 @@ import logging from functools import cache -from azure.identity import ManagedIdentityCredential, get_bearer_token_provider +from azure.identity import ManagedIdentityCredential, get_bearer_token_provider, DefaultAzureCredential from openai import AsyncAzureOpenAI, AsyncOpenAI from .openai_configuration import OpenAIConfiguration @@ -40,7 +40,7 @@ def create_openai_client( return AsyncAzureOpenAI( api_key=configuration.api_key if configuration.api_key else None, azure_ad_token_provider=get_bearer_token_provider( - ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint + DefaultAzureCredential(managed_identity_client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3", exclude_interactive_browser_credential = False), cognitive_services_endpoint ) if not configuration.api_key else None, diff --git a/func-app/graphrag/query/factories.py b/func-app/graphrag/query/factories.py index 28caf61bb0..a3b012e183 100644 --- a/func-app/graphrag/query/factories.py +++ b/func-app/graphrag/query/factories.py @@ -5,7 +5,7 @@ from graphrag.config.models.graphdb_config import GraphDBConfig import tiktoken -from azure.identity import ManagedIdentityCredential, get_bearer_token_provider +from azure.identity import ManagedIdentityCredential, 
get_bearer_token_provider, DefaultAzureCredential from graphrag.config import ( GraphRagConfig, @@ -53,7 +53,8 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI: api_key=config.llm.api_key, azure_ad_token_provider=( get_bearer_token_provider( - ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint + DefaultAzureCredential(managed_identity_client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3", exclude_interactive_browser_credential = False), cognitive_services_endpoint + ) if is_azure_client and not config.llm.api_key else None @@ -85,7 +86,8 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding: api_key=config.embeddings.llm.api_key, azure_ad_token_provider=( get_bearer_token_provider( - ManagedIdentityCredential(client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3"), cognitive_services_endpoint + DefaultAzureCredential(managed_identity_client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3", exclude_interactive_browser_credential = False), cognitive_services_endpoint + ) if is_azure_client and not config.embeddings.llm.api_key else None From 81f0bca23a96f8461164d844cd0d304221a51af5 Mon Sep 17 00:00:00 2001 From: amritpalms Date: Mon, 23 Sep 2024 20:53:14 -0700 Subject: [PATCH 87/87] fix for browser not opening --- func-app/graphrag/llm/openai/create_openai_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/func-app/graphrag/llm/openai/create_openai_client.py b/func-app/graphrag/llm/openai/create_openai_client.py index 1bc0c2e042..d45aa4e6e8 100644 --- a/func-app/graphrag/llm/openai/create_openai_client.py +++ b/func-app/graphrag/llm/openai/create_openai_client.py @@ -41,9 +41,7 @@ def create_openai_client( api_key=configuration.api_key if configuration.api_key else None, azure_ad_token_provider=get_bearer_token_provider( DefaultAzureCredential(managed_identity_client_id="295ce65c-28c6-4763-be6f-a5eb36c3ceb3", exclude_interactive_browser_credential = False), cognitive_services_endpoint - ) - if not configuration.api_key - else None, + ), organization=configuration.organization, # Azure-Specifics api_version=configuration.api_version,