diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index db6b92a4..a13b6f6d 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -65,7 +65,7 @@ import sqlparse from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError -from ..types import TrainingPlan, TrainingPlanItem +from ..types import TrainingPlan, TrainingPlanItem, TableMetadata from ..utils import validate_config_path @@ -210,6 +210,55 @@ def extract_sql(self, llm_response: str) -> str: return llm_response + @staticmethod + def extract_table_metadata(ddl: str) -> TableMetadata: + """ + Example: + ```python + vn.extract_table_metadata("CREATE TABLE hive.bi_ads.customers (id INT, name TEXT, sales DECIMAL)") + ``` + + Extracts the table metadata from a DDL statement. This is useful in case the DDL statement contains other information besides the table metadata. + Override this function if your DDL statements need custom extraction logic. + + Args: + ddl (str): The DDL statement. + + Returns: + TableMetadata: The extracted table metadata. 
+ """ + pattern_with_catalog_schema = re.compile( + r'CREATE TABLE\s+(\w+)\.(\w+)\.(\w+)\s*\(', + re.IGNORECASE + ) + pattern_with_schema = re.compile( + r'CREATE TABLE\s+(\w+)\.(\w+)\s*\(', + re.IGNORECASE + ) + pattern_with_table = re.compile( + r'CREATE TABLE\s+(\w+)\s*\(', + re.IGNORECASE + ) + + match_with_catalog_schema = pattern_with_catalog_schema.search(ddl) + match_with_schema = pattern_with_schema.search(ddl) + match_with_table = pattern_with_table.search(ddl) + + if match_with_catalog_schema: + catalog = match_with_catalog_schema.group(1) + schema = match_with_catalog_schema.group(2) + table_name = match_with_catalog_schema.group(3) + return TableMetadata(catalog, schema, table_name) + elif match_with_schema: + schema = match_with_schema.group(1) + table_name = match_with_schema.group(2) + return TableMetadata(None, schema, table_name) + elif match_with_table: + table_name = match_with_table.group(1) + return TableMetadata(None, None, table_name) + else: + return TableMetadata() + def is_sql_valid(self, sql: str) -> bool: """ Example: @@ -395,6 +443,31 @@ def get_related_ddl(self, question: str, **kwargs) -> list: """ pass + @abstractmethod + def search_tables_metadata(self, + engine: str = None, + catalog: str = None, + schema: str = None, + table_name: str = None, + ddl: str = None, + size: int = 10, + **kwargs) -> list: + """ + This method is used to get similar tables metadata. + + Args: + engine (str): The database engine. + catalog (str): The catalog. + schema (str): The schema. + table_name (str): The table name. + ddl (str): The DDL statement. + size (int): The number of tables to return. + + Returns: + list: A list of tables metadata. 
+ """ + pass + @abstractmethod def get_related_documentation(self, question: str, **kwargs) -> list: """ @@ -423,12 +496,13 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str: pass @abstractmethod - def add_ddl(self, ddl: str, **kwargs) -> str: + def add_ddl(self, ddl: str, engine: str = None, **kwargs) -> str: """ This method is used to add a DDL statement to the training data. Args: ddl (str): The DDL statement to add. + engine (str): The database engine that the DDL statement applies to. Returns: str: The ID of the training data that was added. @@ -1778,6 +1852,7 @@ def train( question: str = None, sql: str = None, ddl: str = None, + engine: str = None, documentation: str = None, plan: TrainingPlan = None, ) -> str: @@ -1798,8 +1873,11 @@ def train( question (str): The question to train on. sql (str): The SQL query to train on. ddl (str): The DDL statement. + engine (str): The database engine. documentation (str): The documentation to train on. plan (TrainingPlan): The training plan to train on. 
+ Returns: + str: The ID of the training data that was added. """ if question and not sql: @@ -1817,12 +1895,12 @@ if ddl: print("Adding ddl:", ddl) - return self.add_ddl(ddl) + return self.add_ddl(ddl=ddl, engine=engine) if plan: for item in plan._plan: if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL: - self.add_ddl(item.item_value) + self.add_ddl(ddl=item.item_value, engine=engine) elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS: self.add_documentation(item.item_value) elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL: diff --git a/src/vanna/opensearch/opensearch_vector.py b/src/vanna/opensearch/opensearch_vector.py index 7fab1ecd..a6047a96 100644 --- a/src/vanna/opensearch/opensearch_vector.py +++ b/src/vanna/opensearch/opensearch_vector.py @@ -4,8 +4,10 @@ import pandas as pd from opensearchpy import OpenSearch +from ..types import TableMetadata from ..base import VannaBase +from ..utils import deterministic_uuid class OpenSearch_VectorStore(VannaBase): @@ -56,6 +58,18 @@ def __init__(self, config=None): }, "mappings": { "properties": { + "engine": { + "type": "keyword", + }, + "catalog": { + "type": "keyword", + }, + "schema": { + "type": "keyword", + }, + "table_name": { + "type": "keyword", + }, "ddl": { "type": "text", }, @@ -92,6 +106,8 @@ def __init__(self, config=None): if config is not None and "es_question_sql_index_settings" in config: question_sql_index_settings = config["es_question_sql_index_settings"] + self.n_results = config.get("n_results", 10) + self.document_index_settings = document_index_settings self.ddl_index_settings = ddl_index_settings self.question_sql_index_settings = question_sql_index_settings @@ -231,10 +247,29 @@ def create_index_if_not_exists(self, index_name: str, print(f"Error creating index: {index_name} ", e) return False - def add_ddl(self, ddl: str, **kwargs) -> str: + def calculate_md5(self, string: str) -> str: + # Encode the string to bytes + string_bytes = string.encode('utf-8') + # Compute the MD5 hash + md5_hash = hashlib.md5(string_bytes) + # 
Get the hex digest of the hash + md5_hex = md5_hash.hexdigest() + return md5_hex + + def add_ddl(self, ddl: str, engine: str = None, + **kwargs) -> str: # Assuming that you have a DDL index in your OpenSearch - id = str(uuid.uuid4()) + "-ddl" + table_metadata = VannaBase.extract_table_metadata(ddl) + full_table_name = table_metadata.get_full_table_name() + if full_table_name is not None and engine is not None: + id = deterministic_uuid(engine + "-" + full_table_name) + "-ddl" + else: + id = str(uuid.uuid4()) + "-ddl" ddl_dict = { + "engine": engine, + "catalog": table_metadata.catalog, + "schema": table_metadata.schema, + "table_name": table_metadata.table_name, "ddl": ddl } response = self.client.index(index=self.ddl_index, body=ddl_dict, id=id, @@ -270,7 +305,8 @@ def get_related_ddl(self, question: str, **kwargs) -> List[str]: "match": { "ddl": question } - } + }, + "size": self.n_results } print(query) response = self.client.search(index=self.ddl_index, body=query, @@ -283,7 +319,8 @@ def get_related_documentation(self, question: str, **kwargs) -> List[str]: "match": { "doc": question } - } + }, + "size": self.n_results } print(query) response = self.client.search(index=self.document_index, @@ -297,7 +334,8 @@ def get_similar_question_sql(self, question: str, **kwargs) -> List[str]: "match": { "question": question } - } + }, + "size": self.n_results } print(query) response = self.client.search(index=self.question_sql_index, @@ -306,6 +344,50 @@ return [(hit['_source']['question'], hit['_source']['sql']) for hit in response['hits']['hits']] + def search_tables_metadata(self, + engine: str = None, + catalog: str = None, + schema: str = None, + table_name: str = None, + ddl: str = None, + size: int = 10, + **kwargs) -> list: + # Assume you have some vector search mechanism associated with your data + query = {} + if engine is None and catalog is None and schema is None and table_name is None and ddl is None: 
+ query = { + "query": { + "match_all": {} + } + } + else: + query["query"] = { + "bool": { + "should": [ + ] + } + } + if engine is not None: + query["query"]["bool"]["should"].append({"match": {"engine": engine}}) + + if catalog is not None: + query["query"]["bool"]["should"].append({"match": {"catalog": catalog}}) + + if schema is not None: + query["query"]["bool"]["should"].append({"match": {"schema": schema}}) + if table_name is not None: + query["query"]["bool"]["should"].append({"match": {"table_name": table_name}}) + + if ddl is not None: + query["query"]["bool"]["should"].append({"match": {"ddl": ddl}}) + + if size > 0: + query["size"] = size + + print(query) + response = self.client.search(index=self.ddl_index, body=query, **kwargs) + return [hit['_source'] for hit in response['hits']['hits']] + def get_training_data(self, **kwargs) -> pd.DataFrame: # This will be a simple example pulling all data from an index # WARNING: Do not use this approach in production for large indices! 
@@ -315,7 +397,6 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: body={"query": {"match_all": {}}}, size=1000 ) - print(query) # records = [hit['_source'] for hit in response['hits']['hits']] for hit in response['hits']['hits']: data.append( diff --git a/src/vanna/types/__init__.py b/src/vanna/types/__init__.py index f3841c88..c1904c3e 100644 --- a/src/vanna/types/__init__.py +++ b/src/vanna/types/__init__.py @@ -290,3 +290,31 @@ def remove_item(self, item: str): if str(plan_item) == item: self._plan.remove(plan_item) break + + +class TableMetadata: + def __init__(self, catalog=None, schema=None, table_name=None): + self.catalog = catalog + self.schema = schema + self.table_name = table_name + + def __str__(self): + parts = [] + if self.catalog: + parts.append(f"Catalog: {self.catalog}") + if self.schema: + parts.append(f"Schema: {self.schema}") + if self.table_name: + parts.append(f"Table: {self.table_name}") + return "\n".join(parts) if parts else "No match found" + + def get_full_table_name(self): + if self.table_name is None: + return None + if self.catalog and self.schema: + return f"{self.catalog}.{self.schema}.{self.table_name}" + elif self.schema: + return f"{self.schema}.{self.table_name}" + else: + return f"{self.table_name}" +