Merge pull request #463 from zyclove/opensearch_fix
[feat] add database engine and table name to support table DDL update
zainhoda authored Nov 21, 2024
2 parents cd29916 + 54a4f8e commit ad014c4
Showing 3 changed files with 195 additions and 10 deletions.
86 changes: 82 additions & 4 deletions src/vanna/base/base.py
@@ -65,7 +65,7 @@
import sqlparse

from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem
from ..types import TrainingPlan, TrainingPlanItem, TableMetadata
from ..utils import validate_config_path


@@ -210,6 +210,54 @@ def extract_sql(self, llm_response: str) -> str:

return llm_response

@staticmethod
def extract_table_metadata(ddl: str) -> TableMetadata:
"""
Example:
```python
vn.extract_table_metadata("CREATE TABLE hive.bi_ads.customers (id INT, name TEXT, sales DECIMAL)")
```
Extracts the table metadata from a DDL statement. This is useful in case the DDL statement contains other information besides the table metadata.
Override this function if your DDL statements need custom extraction logic.
Args:
ddl (str): The DDL statement.
Returns:
TableMetadata: The extracted table metadata.
"""
pattern_with_catalog_schema = re.compile(
r'CREATE TABLE\s+(\w+)\.(\w+)\.(\w+)\s*\(',
re.IGNORECASE
)
pattern_with_schema = re.compile(
r'CREATE TABLE\s+(\w+)\.(\w+)\s*\(',
re.IGNORECASE
)
pattern_with_table = re.compile(
r'CREATE TABLE\s+(\w+)\s*\(',
re.IGNORECASE
)

match_with_catalog_schema = pattern_with_catalog_schema.search(ddl)
match_with_schema = pattern_with_schema.search(ddl)
match_with_table = pattern_with_table.search(ddl)

if match_with_catalog_schema:
catalog = match_with_catalog_schema.group(1)
schema = match_with_catalog_schema.group(2)
table_name = match_with_catalog_schema.group(3)
return TableMetadata(catalog, schema, table_name)
elif match_with_schema:
schema = match_with_schema.group(1)
table_name = match_with_schema.group(2)
return TableMetadata(None, schema, table_name)
elif match_with_table:
table_name = match_with_table.group(1)
return TableMetadata(None, None, table_name)
else:
return TableMetadata()
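A quick sketch of what the three regex branches above extract, using get_full_table_name() to show the result (the import path mirrors the relative imports used elsewhere in this file):

```python
from vanna.base import VannaBase

# Fully qualified name: catalog, schema and table are all captured.
meta = VannaBase.extract_table_metadata(
    "CREATE TABLE hive.bi_ads.customers (id INT, name TEXT, sales DECIMAL)"
)
print(meta.get_full_table_name())  # hive.bi_ads.customers

# schema.table: catalog stays None.
meta = VannaBase.extract_table_metadata("CREATE TABLE bi_ads.customers (id INT)")
print(meta.get_full_table_name())  # bi_ads.customers

# Bare table name: only table_name is set.
meta = VannaBase.extract_table_metadata("CREATE TABLE customers (id INT)")
print(meta.get_full_table_name())  # customers
```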

def is_sql_valid(self, sql: str) -> bool:
"""
Example:
@@ -395,6 +443,31 @@ def get_related_ddl(self, question: str, **kwargs) -> list:
"""
pass

@abstractmethod
def search_tables_metadata(self,
engine: str = None,
catalog: str = None,
schema: str = None,
table_name: str = None,
ddl: str = None,
size: int = 10,
**kwargs) -> list:
"""
This method is used to search for the metadata of similar tables.
Args:
engine (str): The database engine.
catalog (str): The catalog.
schema (str): The schema.
table_name (str): The table name.
ddl (str): The DDL statement.
size (int): The number of tables to return.
Returns:
list: A list of table metadata records.
"""
pass
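A hedged usage sketch, assuming `vn` is a concrete store that implements this method (such as the OpenSearch_VectorStore changed later in this diff); the filter values are illustrative:

```python
# Any subset of the filters may be supplied; the store decides how to combine them.
tables = vn.search_tables_metadata(
    engine="hive",
    schema="bi_ads",
    table_name="customers",
    size=5,
)
for table in tables:
    print(table)  # one metadata record per matching table
```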

@abstractmethod
def get_related_documentation(self, question: str, **kwargs) -> list:
"""
@@ -423,12 +496,13 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
pass

@abstractmethod
def add_ddl(self, ddl: str, **kwargs) -> str:
def add_ddl(self, ddl: str, engine: str = None, **kwargs) -> str:
"""
This method is used to add a DDL statement to the training data.
Args:
ddl (str): The DDL statement to add.
engine (str): The database engine that the DDL statement applies to.
Returns:
str: The ID of the training data that was added.
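A usage sketch of the extended signature; the DDL and engine values are illustrative:

```python
# Tagging the DDL with its engine lets stores that support it (such as the
# OpenSearch implementation below) derive a stable id from engine + table name,
# so re-training the same table updates the existing entry instead of adding a duplicate.
training_id = vn.add_ddl(
    ddl="CREATE TABLE hive.bi_ads.customers (id INT, name TEXT, sales DECIMAL)",
    engine="hive",
)
print(training_id)
```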
@@ -1778,6 +1852,7 @@ def train(
question: str = None,
sql: str = None,
ddl: str = None,
engine: str = None,
documentation: str = None,
plan: TrainingPlan = None,
) -> str:
@@ -1798,8 +1873,11 @@ def train(
question (str): The question to train on.
sql (str): The SQL query to train on.
ddl (str): The DDL statement.
engine (str): The database engine.
documentation (str): The documentation to train on.
plan (TrainingPlan): The training plan to train on.
Returns:
str: The training plan
"""

if question and not sql:
@@ -1817,12 +1895,12 @@ def train(

if ddl:
print("Adding ddl:", ddl)
return self.add_ddl(ddl)
return self.add_ddl(ddl=ddl, engine=engine)

if plan:
for item in plan._plan:
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
self.add_ddl(item.item_value)
self.add_ddl(ddl=item.item_value, engine=engine)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
self.add_documentation(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
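A sketch of the new call path in train(): the engine is forwarded to add_ddl both for a single DDL statement and for every DDL item in a training plan (values illustrative):

```python
# Single DDL statement, tagged with the engine it came from.
vn.train(
    ddl="CREATE TABLE hive.bi_ads.customers (id INT, name TEXT, sales DECIMAL)",
    engine="hive",
)

# Training plan: every ITEM_TYPE_DDL entry is added with the same engine.
# vn.train(plan=plan, engine="hive")
```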
93 changes: 87 additions & 6 deletions src/vanna/opensearch/opensearch_vector.py
@@ -4,8 +4,10 @@

import pandas as pd
from opensearchpy import OpenSearch
from ..types import TableMetadata

from ..base import VannaBase
from ..utils import deterministic_uuid


class OpenSearch_VectorStore(VannaBase):
@@ -56,6 +58,18 @@ def __init__(self, config=None):
},
"mappings": {
"properties": {
"engine": {
"type": "keyword",
},
"catalog": {
"type": "keyword",
},
"schema": {
"type": "keyword",
},
"table_name": {
"type": "keyword",
},
"ddl": {
"type": "text",
},
@@ -92,6 +106,8 @@ def __init__(self, config=None):
if config is not None and "es_question_sql_index_settings" in config:
question_sql_index_settings = config["es_question_sql_index_settings"]

self.n_results = config.get("n_results", 10) if config is not None else 10

self.document_index_settings = document_index_settings
self.ddl_index_settings = ddl_index_settings
self.question_sql_index_settings = question_sql_index_settings
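The new option is read straight from the constructor config; a minimal sketch of the relevant key (connection and LLM-related keys are omitted here rather than guessed, and MyVanna stands in for whatever concrete class composes this store):

```python
# Caps the number of hits returned by get_related_ddl, get_related_documentation
# and get_similar_question_sql; defaults to 10 when the key is absent.
config = {"n_results": 5}
# ...passed to whatever concrete Vanna class composes OpenSearch_VectorStore:
# vn = MyVanna(config=config)
```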
@@ -231,10 +247,29 @@ def create_index_if_not_exists(self, index_name: str,
print(f"Error creating index: {index_name} ", e)
return False

def add_ddl(self, ddl: str, **kwargs) -> str:
def calculate_md5(self, string: str) -> str:
# Encode the string to bytes
string_bytes = string.encode('utf-8')
# Compute the MD5 hash
md5_hash = hashlib.md5(string_bytes)
# Get the hexadecimal representation of the hash
md5_hex = md5_hash.hexdigest()
return md5_hex
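With the fix above (encoding the `string` argument rather than `self`), the helper is a thin wrapper over hashlib; the equivalent standalone computation is:

```python
import hashlib

# Same digest that calculate_md5("hive-hive.bi_ads.customers") would produce.
digest = hashlib.md5("hive-hive.bi_ads.customers".encode("utf-8")).hexdigest()
print(digest)  # a 32-character hex string
```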

def add_ddl(self, ddl: str, engine: str = None,
**kwargs) -> str:
# Assuming that you have a DDL index in your OpenSearch
id = str(uuid.uuid4()) + "-ddl"
table_metadata = VannaBase.extract_table_metadata(ddl)
full_table_name = table_metadata.get_full_table_name()
if full_table_name is not None and engine is not None:
id = deterministic_uuid(engine + "-" + full_table_name) + "-ddl"
else:
id = str(uuid.uuid4()) + "-ddl"
ddl_dict = {
"engine": engine,
"catalog": table_metadata.catalog,
"schema": table_metadata.schema,
"table_name": table_metadata.table_name,
"ddl": ddl
}
response = self.client.index(index=self.ddl_index, body=ddl_dict, id=id,
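For a concrete call, the document body and id produced by the code above look roughly like this (a sketch; deterministic_uuid is assumed to map its input string to a stable UUID):

```python
# add_ddl(ddl="CREATE TABLE hive.bi_ads.customers (id INT)", engine="hive") indexes:
ddl_dict = {
    "engine": "hive",
    "catalog": "hive",
    "schema": "bi_ads",
    "table_name": "customers",
    "ddl": "CREATE TABLE hive.bi_ads.customers (id INT)",
}
# with id = deterministic_uuid("hive-hive.bi_ads.customers") + "-ddl",
# so re-adding the same engine + table overwrites the existing document.
```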
@@ -270,7 +305,8 @@ def get_related_ddl(self, question: str, **kwargs) -> List[str]:
"match": {
"ddl": question
}
}
},
"size": self.n_results
}
print(query)
response = self.client.search(index=self.ddl_index, body=query,
@@ -283,7 +319,8 @@ def get_related_documentation(self, question: str, **kwargs) -> List[str]:
"match": {
"doc": question
}
}
},
"size": self.n_results
}
print(query)
response = self.client.search(index=self.document_index,
@@ -297,7 +334,8 @@ def get_similar_question_sql(self, question: str, **kwargs) -> List[str]:
"match": {
"question": question
}
}
},
"size": self.n_results
}
print(query)
response = self.client.search(index=self.question_sql_index,
@@ -306,6 +344,50 @@ def get_similar_question_sql(self, question: str, **kwargs) -> List[str]:
return [(hit['_source']['question'], hit['_source']['sql']) for hit in
response['hits']['hits']]

def search_tables_metadata(self,
engine: str = None,
catalog: str = None,
schema: str = None,
table_name: str = None,
ddl: str = None,
size: int = 10,
**kwargs) -> list:
# Assume you have some vector search mechanism associated with your data
query = {}
if engine is None and catalog is None and schema is None and table_name is None and ddl is None:
query = {
"query": {
"match_all": {}
}
}
else:
query["query"] = {
"bool": {
"should": [
]
}
}
if engine is not None:
query["query"]["bool"]["should"].append({"match": {"engine": engine}})

if catalog is not None:
query["query"]["bool"]["should"].append({"match": {"catalog": catalog}})

if schema is not None:
query["query"]["bool"]["should"].append({"match": {"schema": schema}})
if table_name is not None:
query["query"]["bool"]["should"].append({"match": {"table_name": table_name}})

if ddl is not None:
query["query"]["bool"]["should"].append({"match": {"ddl": ddl}})

if size > 0:
query["size"] = size

print(query)
response = self.client.search(index=self.ddl_index, body=query, **kwargs)
return [hit['_source'] for hit in response['hits']['hits']]
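For example, filtering on engine and schema with size=5, the code above assembles the following query body; with no filters at all it falls back to a match_all query:

```python
# search_tables_metadata(engine="hive", schema="bi_ads", size=5) sends:
query = {
    "query": {
        "bool": {
            "should": [
                {"match": {"engine": "hive"}},
                {"match": {"schema": "bi_ads"}},
            ]
        }
    },
    "size": 5,
}
# No filters: {"query": {"match_all": {}}, "size": 10} with the default size.
```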

def get_training_data(self, **kwargs) -> pd.DataFrame:
# This will be a simple example pulling all data from an index
# WARNING: Do not use this approach in production for large indices!
@@ -315,7 +397,6 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:
body={"query": {"match_all": {}}},
size=1000
)
print(query)
# records = [hit['_source'] for hit in response['hits']['hits']]
for hit in response['hits']['hits']:
data.append(
26 changes: 26 additions & 0 deletions src/vanna/types/__init__.py
@@ -290,3 +290,29 @@ def remove_item(self, item: str):
if str(plan_item) == item:
self._plan.remove(plan_item)
break


class TableMetadata:
def __init__(self, catalog=None, schema=None, table_name=None):
self.catalog = catalog
self.schema = schema
self.table_name = table_name

def __str__(self):
parts = []
if self.catalog:
parts.append(f"Catalog: {self.catalog}")
if self.schema:
parts.append(f"Schema: {self.schema}")
if self.table_name:
parts.append(f"Table: {self.table_name}")
return "\n".join(parts) if parts else "No match found"

def get_full_table_name(self):
# Return None (rather than the string "None") when no table name was extracted,
# so callers such as add_ddl can detect a failed extraction.
if self.table_name is None:
return None
if self.catalog and self.schema:
return f"{self.catalog}.{self.schema}.{self.table_name}"
elif self.schema:
return f"{self.schema}.{self.table_name}"
else:
return self.table_name
