From 1fdc77a9cefdbea25effba05071171a1af220bbd Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Mon, 8 Jul 2024 21:01:14 +0800 Subject: [PATCH 01/30] feat: openrouter, qdrant async --- .gitignore | 1 + pyproject.toml | 51 +- src/vanna/base/base.py | 1022 ++++++++++++++++++----- src/vanna/openrouter/__init__.py | 3 + src/vanna/openrouter/openrouter_chat.py | 179 ++++ src/vanna/qdrant/qdrant.py | 273 +++++- 6 files changed, 1297 insertions(+), 232 deletions(-) create mode 100644 src/vanna/openrouter/__init__.py create mode 100644 src/vanna/openrouter/openrouter_chat.py diff --git a/.gitignore b/.gitignore index 69698f91..cc3cdd37 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ htmlcov chroma.sqlite3 *.bin .coverage.* +.vscode \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index bfbc7f3d..f1c94339 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,20 +5,26 @@ build-backend = "flit_core.buildapi" [project] name = "vanna" version = "0.6.2" -authors = [ - { name="Zain Hoda", email="zain@vanna.ai" }, -] +authors = [{ name = "Zain Hoda", email = "zain@vanna.ai" }] description = "Generate SQL queries from natural language" readme = "README.md" requires-python = ">=3.9" classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", ] dependencies = [ - "requests", "tabulate", "plotly", "pandas", "sqlparse", "kaleido", "flask", "flask-sock", "sqlalchemy" + "requests", + "tabulate", + "plotly", + "pandas", + "sqlparse", + "kaleido", + "flask", + "flask-sock", + "sqlalchemy", ] [project.urls] @@ -27,13 +33,38 @@ dependencies = [ [project.optional-dependencies] postgres = ["psycopg2-binary", "db-dtypes"] -mysql = ["PyMySQL"] +mysql = ["PyMySQL", "aiomysql"] clickhouse = ["clickhouse_connect"] bigquery = ["google-cloud-bigquery"] snowflake = ["snowflake-connector-python"] duckdb = ["duckdb"] google = ["google-generativeai", "google-cloud-aiplatform"] -all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]"] +all = [ + "psycopg2-binary", + "db-dtypes", + "PyMySQL", + "aiomysql", + "google-cloud-bigquery", + "snowflake-connector-python", + "duckdb", + "openai", + "mistralai", + "chromadb", + "anthropic", + "zhipuai", + "marqo", + "google-generativeai", + "google-cloud-aiplatform", + "qdrant-client", + "fastembed", + "ollama", + "httpx", + "opensearch-py", + "opensearch-dsl", + "transformers", + "pinecone-client", + "pymilvus[model]", +] test = ["tox"] chromadb = ["chromadb"] openai = ["openai"] @@ -43,7 +74,7 @@ gemini = ["google-generativeai"] marqo = ["marqo"] zhipuai = ["zhipuai"] ollama = ["ollama", "httpx"] -qdrant = ["qdrant-client", "fastembed"] +qdrant = ["qdrant-client[fastembed]"] vllm = ["vllm"] pinecone = ["pinecone-client", "fastembed"] opensearch = ["opensearch-py", "opensearch-dsl"] diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 492516ea..fa7822a5 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -48,13 +48,14 @@ """ +import asyncio import json import os import re import sqlite3 import traceback from abc import ABC, abstractmethod -from typing import List, Tuple, Union +from typing import Any, List, Tuple, Union from urllib.parse import urlparse import pandas as pd @@ -76,6 +77,7 @@ def __init__(self, config=None): self.config = config self.run_sql_is_set = False + self.arun_sql_is_set = False self.static_documentation = "" self.dialect = self.config.get("dialect", "SQL") self.language = self.config.get("language", None) @@ -136,7 +138,7 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> llm_response = self.submit_prompt(prompt, **kwargs) self.log(title="LLM Response", message=llm_response) - if 'intermediate_sql' in llm_response: + if "intermediate_sql" in llm_response: if not allow_llm_to_see_data: return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this." @@ -152,7 +154,11 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> question=question, question_sql_list=question_sql_list, ddl_list=ddl_list, - doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()], + doc_list=doc_list + + [ + f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + + df.to_markdown() + ], **kwargs, ) self.log(title="Final SQL Prompt", message=prompt) @@ -161,6 +167,59 @@ def generate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs) -> except Exception as e: return f"Error running intermediate SQL: {e}" + return self.extract_sql(llm_response) + + async def agenerate_sql( + self, question: str, allow_llm_to_see_data=False, **kwargs + ) -> str: + # TODO: make it async + if self.config is not None: + initial_prompt = self.config.get("initial_prompt", None) + else: + initial_prompt = None + question_sql_list = await self.aget_similar_question_sql(question, **kwargs) + ddl_list = await self.aget_related_ddl(question, **kwargs) + doc_list = await self.aget_related_documentation(question, **kwargs) + prompt = self.get_sql_prompt( + initial_prompt=initial_prompt or "", + question=question, + question_sql_list=question_sql_list, + ddl_list=ddl_list, + doc_list=doc_list, + **kwargs, + ) + self.log(title="SQL Prompt", message=prompt) + llm_response = await self.asubmit_prompt(prompt, **kwargs) + self.log(title="LLM Response", message=llm_response) + + if "intermediate_sql" in llm_response: + if not allow_llm_to_see_data: + return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this." + + if allow_llm_to_see_data: + intermediate_sql = self.extract_sql(llm_response) + + try: + self.log(title="Running Intermediate SQL", message=intermediate_sql) + df = await self.arun_sql(intermediate_sql) + + prompt = self.get_sql_prompt( + initial_prompt=initial_prompt, + question=question, + question_sql_list=question_sql_list, + ddl_list=ddl_list, + doc_list=doc_list + + [ + f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + + df.to_markdown() + ], + **kwargs, + ) + self.log(title="Final SQL Prompt", message=prompt) + llm_response = await self.asubmit_prompt(prompt, **kwargs) + self.log(title="LLM Response", message=llm_response) + except Exception as e: + return f"Error running intermediate SQL: {e}" return self.extract_sql(llm_response) @@ -229,7 +288,7 @@ def is_sql_valid(self, sql: str) -> bool: parsed = sqlparse.parse(sql) for statement in parsed: - if statement.get_type() == 'SELECT': + if statement.get_type() == "SELECT": return True return False @@ -251,7 +310,7 @@ def should_generate_chart(self, df: pd.DataFrame) -> bool: bool: True if a chart should be generated, False otherwise. """ - if len(df) > 1 and df.select_dtypes(include=['number']).shape[1] > 0: + if len(df) > 1 and df.select_dtypes(include=["number"]).shape[1] > 0: return True return False @@ -282,8 +341,8 @@ def generate_followup_questions( f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n" ), self.user_message( - f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." + - self._response_language() + f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." + + self._response_language() ), ] @@ -292,6 +351,41 @@ def generate_followup_questions( numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE) return numbers_removed.split("\n") + async def agenerate_followup_questions( + self, question: str, sql: str, df: pd.DataFrame, n_questions: int = 5, **kwargs + ) -> list: + """ + **Example:** + ```python + vn.generate_followup_questions("What are the top 10 customers by sales?", sql, df) + ``` + + Generate a list of followup questions that you can ask Vanna.AI. + + Args: + question (str): The question that was asked. + sql (str): The LLM-generated SQL query. + df (pd.DataFrame): The results of the SQL query. + n_questions (int): Number of follow-up questions to generate. + + Returns: + list: A list of followup questions that you can ask Vanna.AI. + """ + message_log = [ + self.system_message( + f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n" + ), + self.user_message( + f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query." + + self._response_language() + ), + ] + + llm_response = await self.asubmit_prompt(message_log, **kwargs) + + numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE) + return numbers_removed.split("\n") + def generate_questions(self, **kwargs) -> List[str]: """ **Example:** @@ -305,6 +399,19 @@ def generate_questions(self, **kwargs) -> List[str]: return [q["question"] for q in question_sql] + async def agenerate_questions(self, **kwargs) -> List[str]: + """ + **Example:** + ```python + vn.generate_questions() + ``` + + Generate a list of questions that you can ask Vanna.AI. + """ + question_sql = await self.aget_similar_question_sql(question="", **kwargs) + + return [q["question"] for q in question_sql] + def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str: """ **Example:** @@ -327,8 +434,8 @@ def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str: f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n" ), self.user_message( - "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary." + - self._response_language() + "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary." + + self._response_language() ), ] @@ -336,11 +443,45 @@ def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str: return summary + async def agenerate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str: + """ + **Example:** + ```python + vn.generate_summary("What are the top 10 customers by sales?", df) + ``` + + Generate a summary of the results of a SQL query. + + Args: + question (str): The question that was asked. + df (pd.DataFrame): The results of the SQL query. + + Returns: + str: The summary of the results of the SQL query. + """ + + message_log = [ + self.system_message( + f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n" + ), + self.user_message( + "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary." + + self._response_language() + ), + ] + + summary = await self.asubmit_prompt(message_log, **kwargs) + + return summary + # ----------------- Use Any Embeddings API ----------------- # @abstractmethod - def generate_embedding(self, data: str, **kwargs) -> List[float]: + def generate_embedding(self, data: str, **kwargs) -> list[float]: pass + async def agenerate_embedding(self, data: str, **kwargs) -> list[float]: + raise NotImplementedError + # ----------------- Use Any Database to Store and Retrieve Context ----------------- # @abstractmethod def get_similar_question_sql(self, question: str, **kwargs) -> list: @@ -355,6 +496,18 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list: """ pass + async def aget_similar_question_sql(self, question: str, **kwargs) -> list: + """ + This method is used to get similar questions and their corresponding SQL statements. + + Args: + question (str): The question to get similar questions and their corresponding SQL statements for. + + Returns: + list: A list of similar questions and their corresponding SQL statements. + """ + raise NotImplementedError + @abstractmethod def get_related_ddl(self, question: str, **kwargs) -> list: """ @@ -368,6 +521,18 @@ def get_related_ddl(self, question: str, **kwargs) -> list: """ pass + async def aget_related_ddl(self, question: str, **kwargs) -> list: + """ + This method is used to get related DDL statements to a question. + + Args: + question (str): The question to get related DDL statements for. + + Returns: + list: A list of related DDL statements. + """ + raise NotImplementedError + @abstractmethod def get_related_documentation(self, question: str, **kwargs) -> list: """ @@ -381,6 +546,18 @@ def get_related_documentation(self, question: str, **kwargs) -> list: """ pass + async def aget_related_documentation(self, question: str, **kwargs) -> list: + """ + This method is used to get related documentation to a question. + + Args: + question (str): The question to get related documentation for. + + Returns: + list: A list of related documentation. + """ + raise NotImplementedError + @abstractmethod def add_question_sql(self, question: str, sql: str, **kwargs) -> str: """ @@ -395,6 +572,19 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str: """ pass + async def aadd_question_sql(self, question: str, sql: str, **kwargs) -> str: + """ + This method is used to add a question and its corresponding SQL query to the training data. + + Args: + question (str): The question to add. + sql (str): The SQL query to add. + + Returns: + str: The ID of the training data that was added. + """ + raise NotImplementedError + @abstractmethod def add_ddl(self, ddl: str, **kwargs) -> str: """ @@ -408,6 +598,18 @@ def add_ddl(self, ddl: str, **kwargs) -> str: """ pass + async def aadd_ddl(self, ddl: str, **kwargs) -> str: + """ + This method is used to add a DDL statement to the training data. + + Args: + ddl (str): The DDL statement to add. + + Returns: + str: The ID of the training data that was added. + """ + raise NotImplementedError + @abstractmethod def add_documentation(self, documentation: str, **kwargs) -> str: """ @@ -421,6 +623,18 @@ def add_documentation(self, documentation: str, **kwargs) -> str: """ pass + async def aadd_documentation(self, documentation: str, **kwargs) -> str: + """ + This method is used to add documentation to the training data. + + Args: + documentation (str): The documentation to add. + + Returns: + str: The ID of the training data that was added. + """ + raise NotImplementedError + @abstractmethod def get_training_data(self, **kwargs) -> pd.DataFrame: """ @@ -436,8 +650,22 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: """ pass + async def aget_training_data(self, **kwargs) -> pd.DataFrame: + """ + Example: + ```python + vn.get_training_data() + ``` + + This method is used to get all the training data from the retrieval layer. + + Returns: + pd.DataFrame: The training data. + """ + raise NotImplementedError + @abstractmethod - def remove_training_data(id: str, **kwargs) -> bool: + def remove_training_data(self, id: str, **kwargs) -> bool: """ Example: ```python @@ -454,22 +682,39 @@ def remove_training_data(id: str, **kwargs) -> bool: """ pass + async def aremove_training_data(self, id: str, **kwargs) -> bool: + """ + Example: + ```python + vn.remove_training_data(id="123-ddl") + ``` + + This method is used to remove training data from the retrieval layer. + + Args: + id (str): The ID of the training data to remove. + + Returns: + bool: True if the training data was removed, False otherwise. + """ + raise NotImplementedError + # ----------------- Use Any Language Model API ----------------- # @abstractmethod - def system_message(self, message: str) -> any: + def system_message(self, message: str) -> Any: pass @abstractmethod - def user_message(self, message: str) -> any: + def user_message(self, message: str) -> Any: pass @abstractmethod - def assistant_message(self, message: str) -> any: + def assistant_message(self, message: str) -> Any: pass def str_to_approx_token_count(self, string: str) -> int: - return len(string) / 4 + return int(len(string) / 4) def add_ddl_to_prompt( self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000 @@ -524,7 +769,7 @@ def add_sql_to_prompt( def get_sql_prompt( self, - initial_prompt : str, + initial_prompt: str, question: str, question_sql_list: list, ddl_list: list, @@ -556,8 +801,10 @@ def get_sql_prompt( """ if initial_prompt is None: - initial_prompt = f"You are a {self.dialect} expert. " + \ - "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. " + initial_prompt = ( + f"You are a {self.dialect} expert. " + + "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. " + ) initial_prompt = self.add_ddl_to_prompt( initial_prompt, ddl_list, max_tokens=self.max_tokens @@ -647,6 +894,29 @@ def submit_prompt(self, prompt, **kwargs) -> str: """ pass + @abstractmethod + async def asubmit_prompt(self, prompt, **kwargs) -> str: + """ + Example: + ```python + vn.submit_prompt( + [ + vn.system_message("The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."), + vn.user_message("What are the top 10 customers by sales?"), + ] + ) + ``` + + This method is used to submit a prompt to the LLM. + + Args: + prompt (any): The prompt to submit to the LLM. + + Returns: + str: The response from the LLM. + """ + pass + def generate_question(self, sql: str, **kwargs) -> str: response = self.submit_prompt( [ @@ -660,6 +930,19 @@ def generate_question(self, sql: str, **kwargs) -> str: return response + async def agenerate_question(self, sql: str, **kwargs) -> str: + response = await self.asubmit_prompt( + [ + self.system_message( + "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question." + ), + self.user_message(sql), + ], + **kwargs, + ) + + return response + def _extract_python_code(self, markdown_string: str) -> str: # Regex pattern to match Python code blocks pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" @@ -685,7 +968,11 @@ def _sanitize_plotly_code(self, raw_plotly_code: str) -> str: return plotly_code def generate_plotly_code( - self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs + self, + question: str | None = None, + sql: str | None = None, + df_metadata: str | None = None, + **kwargs, ) -> str: if question is not None: system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'" @@ -708,6 +995,34 @@ def generate_plotly_code( return self._sanitize_plotly_code(self._extract_python_code(plotly_code)) + async def agenerate_plotly_code( + self, + question: str | None = None, + sql: str | None = None, + df_metadata: str | None = None, + **kwargs, + ) -> str: + if question is not None: + system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'" + else: + system_msg = "The following is a pandas DataFrame " + + if sql is not None: + system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n" + + system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}" + + message_log = [ + self.system_message(system_msg), + self.user_message( + "Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code." + ), + ] + + plotly_code = await self.asubmit_prompt(message_log, kwargs=kwargs) + + return self._sanitize_plotly_code(self._extract_python_code(plotly_code)) + # ----------------- Connect to Any Database to run the Generated SQL ----------------- # def connect_to_snowflake( @@ -764,7 +1079,7 @@ def connect_to_snowflake( password=password, account=account, database=database, - client_session_keep_alive=True + client_session_keep_alive=True, ) def run_sql_snowflake(sql: str) -> pd.DataFrame: @@ -929,14 +1244,13 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_postgres - def connect_to_mysql( - self, - host: str = None, - dbname: str = None, - user: str = None, - password: str = None, - port: int = None, + self, + host: str = None, + dbname: str = None, + user: str = None, + password: str = None, + port: int = None, ): try: @@ -980,12 +1294,14 @@ def connect_to_mysql( conn = None try: - conn = pymysql.connect(host=host, - user=user, - password=password, - database=dbname, - port=port, - cursorclass=pymysql.cursors.DictCursor) + conn = pymysql.connect( + host=host, + user=user, + password=password, + database=dbname, + port=port, + cursorclass=pymysql.cursors.DictCursor, + ) except pymysql.Error as e: raise ValidationError(e) @@ -1014,13 +1330,86 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_mysql + async def aconnect_to_mysql( + self, + host: str | None = None, + dbname: str | None = None, + user: str | None = None, + password: str | None = None, + port: int | None = None, + ): + + try: + from sqlalchemy.ext.asyncio import create_async_engine + except ImportError: + raise DependencyError( + "You need to install required dependencies to execute this method," + " run command: \npip install aiomysql" + ) + + if not host: + host = os.getenv("HOST") + + if not host: + raise ImproperlyConfigured("Please set your MySQL host") + + if not dbname: + dbname = os.getenv("DATABASE") + + if not dbname: + raise ImproperlyConfigured("Please set your MySQL database") + + if not user: + user = os.getenv("USER") + + if not user: + raise ImproperlyConfigured("Please set your MySQL user") + + if not password: + password = os.getenv("PASSWORD") + + if not password: + raise ImproperlyConfigured("Please set your MySQL password") + + if not port: + port = int(os.getenv("PORT", 3306)) + + conn = None + + try: + engine = create_async_engine( + url=f"mysql+aiomysql://{user}:{password}@{host}:{port}/{dbname}" + ) + conn = await engine.connect() + except Exception as e: + raise ValidationError(e) + + async def arun_sql_mysql(sql: str, **kwargs) -> pd.DataFrame: + from sqlalchemy import text + + try: + cs = await conn.execute(text(sql)) + results = cs.fetchall() + + columns = cs.keys() + # Create a pandas dataframe from the results + df = pd.DataFrame(results, columns=columns) # type: ignore + return df + + except Exception as e: + await conn.rollback() + raise e + + self.arun_sql_is_set = True + self.arun_sql = arun_sql_mysql + def connect_to_clickhouse( - self, - host: str = None, - dbname: str = None, - user: str = None, - password: str = None, - port: int = None, + self, + host: str = None, + dbname: str = None, + user: str = None, + password: str = None, + port: int = None, ): try: @@ -1087,17 +1476,16 @@ def run_sql_clickhouse(sql: str) -> Union[pd.DataFrame, None]: except Exception as e: raise e - + self.run_sql_is_set = True self.run_sql = run_sql_clickhouse def connect_to_oracle( - self, - user: str = None, - password: str = None, - dsn: str = None, + self, + user: str = None, + password: str = None, + dsn: str = None, ): - """ Connect to an Oracle db using oracledb package. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] **Example:** @@ -1127,7 +1515,9 @@ def connect_to_oracle( dsn = os.getenv("DSN") if not dsn: - raise ImproperlyConfigured("Please set your Oracle dsn which should include host:port/sid") + raise ImproperlyConfigured( + "Please set your Oracle dsn which should include host:port/sid" + ) if not user: user = os.getenv("USER") @@ -1148,7 +1538,7 @@ def connect_to_oracle( user=user, password=password, dsn=dsn, - ) + ) except oracledb.Error as e: raise ValidationError(e) @@ -1156,7 +1546,9 @@ def run_sql_oracle(sql: str) -> Union[pd.DataFrame, None]: if conn: try: sql = sql.rstrip() - if sql.endswith(';'): #fix for a known problem with Oracle db where an extra ; will cause an error. + if sql.endswith( + ";" + ): # fix for a known problem with Oracle db where an extra ; will cause an error. sql = sql[:-1] cs = conn.cursor() @@ -1361,19 +1753,20 @@ def run_sql_mssql(sql: str): self.dialect = "T-SQL / Microsoft SQL Server" self.run_sql = run_sql_mssql self.run_sql_is_set = True + def connect_to_presto( - self, - host: str, - catalog: str = 'hive', - schema: str = 'default', - user: str = None, - password: str = None, - port: int = None, - combined_pem_path: str = None, - protocol: str = 'https', - requests_kwargs: dict = None + self, + host: str, + catalog: str = "hive", + schema: str = "default", + user: str = None, + password: str = None, + port: int = None, + combined_pem_path: str = None, + protocol: str = "https", + requests_kwargs: dict = None, ): - """ + """ Connect to a Presto database using the specified parameters. Args: @@ -1393,99 +1786,101 @@ def connect_to_presto( Returns: None - """ - try: - from pyhive import presto - except ImportError: - raise DependencyError( - "You need to install required dependencies to execute this method," - " run command: \npip install pyhive" - ) + """ + try: + from pyhive import presto + except ImportError: + raise DependencyError( + "You need to install required dependencies to execute this method," + " run command: \npip install pyhive" + ) - if not host: - host = os.getenv("PRESTO_HOST") - - if not host: - raise ImproperlyConfigured("Please set your presto host") - - if not catalog: - catalog = os.getenv("PRESTO_CATALOG") - - if not catalog: - raise ImproperlyConfigured("Please set your presto catalog") - - if not user: - user = os.getenv("PRESTO_USER") - - if not user: - raise ImproperlyConfigured("Please set your presto user") - - if not password: - password = os.getenv("PRESTO_PASSWORD") - - if not port: - port = os.getenv("PRESTO_PORT") - - if not port: - raise ImproperlyConfigured("Please set your presto port") - - conn = None - - try: - if requests_kwargs is None and combined_pem_path is not None: - # use the combined pem file to verify the SSL connection - requests_kwargs = { - 'verify': combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证 - } - conn = presto.Connection(host=host, - username=user, - password=password, - catalog=catalog, - schema=schema, - port=port, - protocol=protocol, - requests_kwargs=requests_kwargs) - except presto.Error as e: - raise ValidationError(e) - - def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]: - if conn: - try: - sql = sql.rstrip() - # fix for a known problem with presto db where an extra ; will cause an error. - if sql.endswith(';'): - sql = sql[:-1] - cs = conn.cursor() - cs.execute(sql) - results = cs.fetchall() + if not host: + host = os.getenv("PRESTO_HOST") - # Create a pandas dataframe from the results - df = pd.DataFrame( - results, columns=[desc[0] for desc in cs.description] - ) - return df + if not host: + raise ImproperlyConfigured("Please set your presto host") - except presto.Error as e: - print(e) + if not catalog: + catalog = os.getenv("PRESTO_CATALOG") + + if not catalog: + raise ImproperlyConfigured("Please set your presto catalog") + + if not user: + user = os.getenv("PRESTO_USER") + + if not user: + raise ImproperlyConfigured("Please set your presto user") + + if not password: + password = os.getenv("PRESTO_PASSWORD") + + if not port: + port = os.getenv("PRESTO_PORT") + + if not port: + raise ImproperlyConfigured("Please set your presto port") + + conn = None + + try: + if requests_kwargs is None and combined_pem_path is not None: + # use the combined pem file to verify the SSL connection + requests_kwargs = { + "verify": combined_pem_path, # 使用转换后得到的 PEM 文件进行 SSL 验证 + } + conn = presto.Connection( + host=host, + username=user, + password=password, + catalog=catalog, + schema=schema, + port=port, + protocol=protocol, + requests_kwargs=requests_kwargs, + ) + except presto.Error as e: raise ValidationError(e) - except Exception as e: - print(e) - raise e + def run_sql_presto(sql: str) -> Union[pd.DataFrame, None]: + if conn: + try: + sql = sql.rstrip() + # fix for a known problem with presto db where an extra ; will cause an error. + if sql.endswith(";"): + sql = sql[:-1] + cs = conn.cursor() + cs.execute(sql) + results = cs.fetchall() - self.run_sql_is_set = True - self.run_sql = run_sql_presto + # Create a pandas dataframe from the results + df = pd.DataFrame( + results, columns=[desc[0] for desc in cs.description] + ) + return df + + except presto.Error as e: + print(e) + raise ValidationError(e) + + except Exception as e: + print(e) + raise e + + self.run_sql_is_set = True + self.run_sql = run_sql_presto def connect_to_hive( - self, - host: str = None, - dbname: str = 'default', - user: str = None, - password: str = None, - port: int = None, - auth: str = 'CUSTOM' + self, + host: str = None, + dbname: str = "default", + user: str = None, + password: str = None, + port: int = None, + auth: str = "CUSTOM", ): - """ + """ Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] Connect to a Hive database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] @@ -1499,78 +1894,80 @@ def connect_to_hive( Returns: None - """ - - try: - from pyhive import hive - except ImportError: - raise DependencyError( - "You need to install required dependencies to execute this method," - " run command: \npip install pyhive" - ) - - if not host: - host = os.getenv("HIVE_HOST") + """ - if not host: - raise ImproperlyConfigured("Please set your hive host") + try: + from pyhive import hive + except ImportError: + raise DependencyError( + "You need to install required dependencies to execute this method," + " run command: \npip install pyhive" + ) - if not dbname: - dbname = os.getenv("HIVE_DATABASE") + if not host: + host = os.getenv("HIVE_HOST") - if not dbname: - raise ImproperlyConfigured("Please set your hive database") + if not host: + raise ImproperlyConfigured("Please set your hive host") - if not user: - user = os.getenv("HIVE_USER") + if not dbname: + dbname = os.getenv("HIVE_DATABASE") - if not user: - raise ImproperlyConfigured("Please set your hive user") + if not dbname: + raise ImproperlyConfigured("Please set your hive database") - if not password: - password = os.getenv("HIVE_PASSWORD") + if not user: + user = os.getenv("HIVE_USER") - if not port: - port = os.getenv("HIVE_PORT") + if not user: + raise ImproperlyConfigured("Please set your hive user") - if not port: - raise ImproperlyConfigured("Please set your hive port") + if not password: + password = os.getenv("HIVE_PASSWORD") - conn = None + if not port: + port = os.getenv("HIVE_PORT") - try: - conn = hive.Connection(host=host, - username=user, - password=password, - database=dbname, - port=port, - auth=auth) - except hive.Error as e: - raise ValidationError(e) + if not port: + raise ImproperlyConfigured("Please set your hive port") - def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]: - if conn: - try: - cs = conn.cursor() - cs.execute(sql) - results = cs.fetchall() + conn = None - # Create a pandas dataframe from the results - df = pd.DataFrame( - results, columns=[desc[0] for desc in cs.description] + try: + conn = hive.Connection( + host=host, + username=user, + password=password, + database=dbname, + port=port, + auth=auth, ) - return df - - except hive.Error as e: - print(e) + except hive.Error as e: raise ValidationError(e) - except Exception as e: - print(e) - raise e + def run_sql_hive(sql: str) -> Union[pd.DataFrame, None]: + if conn: + try: + cs = conn.cursor() + cs.execute(sql) + results = cs.fetchall() - self.run_sql_is_set = True - self.run_sql = run_sql_hive + # Create a pandas dataframe from the results + df = pd.DataFrame( + results, columns=[desc[0] for desc in cs.description] + ) + return df + + except hive.Error as e: + print(e) + raise ValidationError(e) + + except Exception as e: + print(e) + raise e + + self.run_sql_is_set = True + self.run_sql = run_sql_hive def run_sql(self, sql: str, **kwargs) -> pd.DataFrame: """ @@ -1591,6 +1988,25 @@ def run_sql(self, sql: str, **kwargs) -> pd.DataFrame: "You need to connect to a database first by running vn.connect_to_snowflake(), vn.connect_to_postgres(), similar function, or manually set vn.run_sql" ) + async def arun_sql(self, sql: str, **kwargs) -> pd.DataFrame: + """ + Example: + ```python + vn.run_sql("SELECT * FROM my_table") + ``` + + Run a SQL query on the connected database. + + Args: + sql (str): The SQL query to run. + + Returns: + pd.DataFrame: The results of the SQL query. + """ + raise Exception( + "You need to connect to a database first by running vn.connect_to_snowflake(), vn.connect_to_postgres(), similar function, or manually set vn.arun_sql" + ) + def ask( self, question: Union[str, None] = None, @@ -1628,7 +2044,9 @@ def ask( question = input("Enter a question: ") try: - sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data) + sql = self.generate_sql( + question=question, allow_llm_to_see_data=allow_llm_to_see_data + ) except Exception as e: print(e) return None, None, None @@ -1704,6 +2122,121 @@ def ask( return sql, None, None return sql, df, fig + async def aask( + self, + question: Union[str, None] = None, + print_results: bool = True, + auto_train: bool = True, + visualize: bool = True, # if False, will not generate plotly code + allow_llm_to_see_data: bool = False, + ) -> Union[ + Tuple[ + Union[str, None], + Union[pd.DataFrame, None], + Union[plotly.graph_objs.Figure, None], + ], + None, + ]: + """ + **Example:** + ```python + vn.ask("What are the top 10 customers by sales?") + ``` + + Ask Vanna.AI a question and get the SQL query that answers it. + + Args: + question (str): The question to ask. + print_results (bool): Whether to print the results of the SQL query. + auto_train (bool): Whether to automatically train Vanna.AI on the question and SQL query. + visualize (bool): Whether to generate plotly code and display the plotly figure. + + Returns: + Tuple[str, pd.DataFrame, plotly.graph_objs.Figure]: The SQL query, the results of the SQL query, and the plotly figure. + """ + + if question is None: + question = input("Enter a question: ") + + try: + sql = await self.agenerate_sql( + question=question, allow_llm_to_see_data=allow_llm_to_see_data + ) + except Exception as e: + print(e) + return None, None, None + + if print_results: + try: + Code = __import__("IPython.display", fromList=["Code"]).Code + display(Code(sql)) + except Exception as e: + print(sql) + + if self.arun_sql_is_set is False: + print( + "If you want to run the SQL query, connect to a database first. See here: https://vanna.ai/docs/databases.html" + ) + + if print_results: + return None + else: + return sql, None, None + + try: + df = await self.arun_sql(sql) + + if print_results: + try: + display = __import__( + "IPython.display", fromList=["display"] + ).display + display(df) + except Exception as e: + print(df) + + if len(df) > 0 and auto_train: + await self.aadd_question_sql(question=question, sql=sql) + # Only generate plotly code if visualize is True + if visualize: + try: + plotly_code = await self.agenerate_plotly_code( + question=question, + sql=sql, + df_metadata=f"Running df.dtypes gives:\n {df.dtypes}", + ) + fig = self.get_plotly_figure(plotly_code=plotly_code, df=df) + if print_results: + try: + display = __import__( + "IPython.display", fromlist=["display"] + ).display + Image = __import__( + "IPython.display", fromlist=["Image"] + ).Image + img_bytes = fig.to_image(format="png", scale=2) + display(Image(img_bytes)) + except Exception as e: + fig.show() + except Exception as e: + # Print stack trace + traceback.print_exc() + print("Couldn't run plotly code: ", e) + if print_results: + return None + else: + return sql, df, None + else: + return sql, df, None + + except Exception as e: + print("Couldn't run sql: ", e) + if print_results: + return None + else: + return sql, None, None + return sql, df, fig + def train( self, question: str = None, @@ -1759,6 +2292,63 @@ def train( elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL: self.add_question_sql(question=item.item_name, sql=item.item_value) + async def atrain( + self, + question: str | None = None, + sql: str | None = None, + ddl: str | None = None, + documentation: str | None = None, + plan: TrainingPlan | None = None, + ) -> str: + """ + **Example:** + ```python + vn.train() + ``` + + Train Vanna.AI on a question and its corresponding SQL query. + If you call it with no arguments, it will check if you connected to a database and it will attempt to train on the metadata of that database. + If you call it with the sql argument, it's equivalent to [`vn.add_question_sql()`][vanna.base.base.VannaBase.add_question_sql]. + If you call it with the ddl argument, it's equivalent to [`vn.add_ddl()`][vanna.base.base.VannaBase.add_ddl]. + If you call it with the documentation argument, it's equivalent to [`vn.add_documentation()`][vanna.base.base.VannaBase.add_documentation]. + Additionally, you can pass a [`TrainingPlan`][vanna.types.TrainingPlan] object. Get a training plan with [`vn.get_training_plan_generic()`][vanna.base.base.VannaBase.get_training_plan_generic]. + + Args: + question (str): The question to train on. + sql (str): The SQL query to train on. + ddl (str): The DDL statement. + documentation (str): The documentation to train on. + plan (TrainingPlan): The training plan to train on. + """ + + if question and not sql: + raise ValidationError("Please also provide a SQL query") + + if documentation: + print("Adding documentation....") + return await self.aadd_documentation(documentation) + + if sql: + if question is None: + question = await self.agenerate_question(sql) + print("Question generated with sql:", question, "\nAdding SQL...") + return await self.aadd_question_sql(question=question, sql=sql) + + if ddl: + print("Adding ddl:", ddl) + return await self.aadd_ddl(ddl) + + if plan: + for item in plan._plan: + if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL: + await self.aadd_ddl(item.item_value) + elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS: + await self.aadd_documentation(item.item_value) + elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL: + await self.aadd_question_sql( + question=item.item_name, sql=item.item_value + ) + def _get_databases(self) -> List[str]: try: print("Trying INFORMATION_SCHEMA.DATABASES") @@ -1774,11 +2364,35 @@ def _get_databases(self) -> List[str]: return df_databases["DATABASE_NAME"].unique().tolist() + async def _aget_databases(self) -> List[str]: + try: + print("Trying INFORMATION_SCHEMA.DATABASES") + df_databases = await self.arun_sql( + "SELECT * FROM INFORMATION_SCHEMA.DATABASES" + ) + except Exception as e: + print(e) + try: + print("Trying SHOW DATABASES") + df_databases = await self.arun_sql("SHOW DATABASES") + except Exception as e: + print(e) + return [] + + return df_databases["DATABASE_NAME"].unique().tolist() + def _get_information_schema_tables(self, database: str) -> pd.DataFrame: df_tables = self.run_sql(f"SELECT * FROM {database}.INFORMATION_SCHEMA.TABLES") return df_tables + async def _aget_information_schema_tables(self, database: str) -> pd.DataFrame: + df_tables = await self.arun_sql( + f"SELECT * FROM {database}.INFORMATION_SCHEMA.TABLES" + ) + + return df_tables + def get_training_plan_generic(self, df) -> TrainingPlan: """ This method is used to generate a training plan from an information schema dataframe. @@ -1802,12 +2416,8 @@ def get_training_plan_generic(self, df) -> TrainingPlan: table_column = df.columns[ df.columns.str.lower().str.contains("table_name") ].to_list()[0] - columns = [database_column, - schema_column, - table_column] - candidates = ["column_name", - "data_type", - "comment"] + columns = [database_column, schema_column, table_column] + candidates = ["column_name", "data_type", "comment"] matches = df.columns.str.lower().str.contains("|".join(candidates), regex=True) columns += df.columns[matches].to_list() diff --git a/src/vanna/openrouter/__init__.py b/src/vanna/openrouter/__init__.py new file mode 100644 index 00000000..27af0590 --- /dev/null +++ b/src/vanna/openrouter/__init__.py @@ -0,0 +1,3 @@ +from .openrouter_chat import OpenRouter_Chat + +__all__ = ["OpenRouter_Chat"] diff --git a/src/vanna/openrouter/openrouter_chat.py b/src/vanna/openrouter/openrouter_chat.py new file mode 100644 index 00000000..8ed4eab6 --- /dev/null +++ b/src/vanna/openrouter/openrouter_chat.py @@ -0,0 +1,179 @@ +import os +from typing import Any, AsyncIterable + +from openai import AsyncOpenAI, OpenAI +from openai.types.chat.chat_completion import Choice + +from ..base import VannaBase + + +class OpenRouter_Chat(VannaBase): + def __init__( + self, + client=None, + aclient=None, + config: dict[str, Any] | None = None, + ): + VannaBase.__init__(self, config=config) + # default parameters - can be overrided using config + self.temperature = 0.7 + self.max_tokens = 500 + + if "temperature" in config: + self.temperature = config["temperature"] + + if "max_tokens" in config: + self.max_tokens = config["max_tokens"] + + if "api_type" in config: + raise Exception( + "Passing api_type is now deprecated. Please pass an OpenAI client instead." + ) + + if "api_base" in config: + raise Exception( + "Passing api_base is now deprecated. Please pass an OpenAI client instead." + ) + + if "api_version" in config: + raise Exception( + "Passing api_version is now deprecated. Please pass an OpenAI client instead." + ) + + if client is not None: + self.client = client + return + + if config is None and client is None: + self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + return + + if aclient is not None: + self.aclient = aclient + return + + if config is None and aclient is None: + self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) + return + + if "api_key" in config: + self.client = OpenAI(api_key=config["api_key"]) + self.aclient = AsyncOpenAI(api_key=config["api_key"]) + + def system_message(self, message: str) -> Any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> Any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> Any: + return {"role": "assistant", "content": message} + + def submit_prompt( + self, prompt, model: str = "deepseek/deepseek-chat", **kwargs + ) -> str: + if prompt is None: + raise Exception("Prompt is None") + + if len(prompt) == 0: + raise Exception("Prompt is empty") + + # Count the number of tokens in the message log + # Use 4 as an approximation for the number of characters per token + num_tokens = 0 + for message in prompt: + num_tokens += len(message["content"]) / 4 + + print(f"Using model {model} for {num_tokens} tokens (approx)") + response = self.client.chat.completions.create( + model=model, + messages=prompt, + max_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + + # Find the first response from the chatbot that has text in it (some responses may not have text) + for choice in response.choices: + if not isinstance(choice, Choice): + return str(choice.text) + # If no response with text is found, return the first response's content (which may be empty) + elif ( + isinstance(choice, Choice) + and choice.message is not None + and choice.message.content is not None + ): + return choice.message.content + else: + return "" + + return "" + + async def asubmit_prompt( + self, prompt, model: str = "deepseek/deepseek-chat", **kwargs + ) -> str: + if prompt is None: + raise Exception("Prompt is None") + + if len(prompt) == 0: + raise Exception("Prompt is empty") + + # Count the number of tokens in the message log + # Use 4 as an approximation for the number of characters per token + num_tokens = 0 + for message in prompt: + num_tokens += len(message["content"]) / 4 + + print(f"Using model {model} for {num_tokens} tokens (approx)") + response = await self.aclient.chat.completions.create( + model=model, + messages=prompt, + max_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + + # Find the first response from the chatbot that has text in it (some responses may not have text) + for choice in response.choices: + if not isinstance(choice, Choice): + return str(choice.text) + # If no response with text is found, return the first response's content (which may be empty) + elif ( + isinstance(choice, Choice) + and choice.message is not None + and choice.message.content is not None + ): + return choice.message.content + else: + return "" + + return "" + + async def astream_submit_prompt( + self, prompt, model: str = "deepseek/deepseek-chat", **kwargs + ) -> AsyncIterable[str]: + if prompt is None: + raise Exception("Prompt is None") + + if len(prompt) == 0: + raise Exception("Prompt is empty") + + # Count the number of tokens in the message log + # Use 4 as an approximation for the number of characters per token + num_tokens = 0 + for message in prompt: + num_tokens += len(message["content"]) / 4 + + print(f"Using model {model} for {num_tokens} tokens (approx)") + stream = await self.aclient.chat.completions.create( + model=model, + messages=prompt, + max_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + stream=True, + ) + + async for chunk in stream: + if chunk.choices[0].delta is not None: + yield chunk.choices[0].delta.content or "" diff --git a/src/vanna/qdrant/qdrant.py b/src/vanna/qdrant/qdrant.py index a6db200e..0ea1811e 100644 --- a/src/vanna/qdrant/qdrant.py +++ b/src/vanna/qdrant/qdrant.py @@ -1,9 +1,9 @@ +import asyncio from functools import cached_property from typing import List, Tuple import pandas as pd -from qdrant_client import QdrantClient, grpc, models -from qdrant_client.http.models.models import UpdateStatus +from qdrant_client import AsyncQdrantClient, QdrantClient, grpc, models from ..base import VannaBase from ..utils import deterministic_uuid @@ -44,7 +44,11 @@ def __init__( config={}, ): VannaBase.__init__(self, config=config) - client = config.get("client") + client = config.get("client", None) + async_client = config.get("async_client", None) + + self._client = client + self._async_client = async_client if client is None: self._client = QdrantClient( @@ -57,14 +61,18 @@ def __init__( path=config.get("path", None), prefix=config.get("prefix", None), ) - elif not isinstance(client, QdrantClient): - raise TypeError( - f"Unsupported client of type {client.__class__} was set in config" + if async_client is None: + self._async_client = AsyncQdrantClient( + location=config.get("location", None), + url=config.get("url", None), + prefer_grpc=config.get("prefer_grpc", False), + https=config.get("https", None), + api_key=config.get("api_key", None), + timeout=config.get("timeout", None), + path=config.get("path", None), + prefix=config.get("prefix", None), ) - else: - self._client = client - self.n_results = config.get("n_results", 10) self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5") self.collection_params = config.get("collection_params", {}) @@ -72,12 +80,8 @@ def __init__( self.documentation_collection_name = config.get( "documentation_collection_name", "documentation" ) - self.ddl_collection_name = config.get( - "ddl_collection_name", "ddl" - ) - self.sql_collection_name = config.get( - "sql_collection_name", "sql" - ) + self.ddl_collection_name = config.get("ddl_collection_name", "ddl") + self.sql_collection_name = config.get("sql_collection_name", "sql") self.id_suffixes = { self.ddl_collection_name: "ddl", @@ -85,7 +89,10 @@ def __init__( self.sql_collection_name: "sql", } - self._setup_collections() + if async_client: + asyncio.run(self._asetup_collections()) + else: + self._setup_collections() def add_question_sql(self, question: str, sql: str, **kwargs) -> str: question_answer = "Question: {0}\n\nSQL: {1}".format(question, sql) @@ -107,6 +114,26 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str: return self._format_point_id(id, self.sql_collection_name) + async def aadd_question_sql(self, question: str, sql: str, **kwargs) -> str: + question_answer = "Question: {0}\n\nSQL: {1}".format(question, sql) + id = deterministic_uuid(question_answer) + + await self._async_client.upsert( + self.sql_collection_name, + points=[ + models.PointStruct( + id=id, + vector=await self.agenerate_embedding(question_answer), + payload={ + "question": question, + "sql": sql, + }, + ) + ], + ) + + return self._format_point_id(id, self.sql_collection_name) + def add_ddl(self, ddl: str, **kwargs) -> str: id = deterministic_uuid(ddl) self._client.upsert( @@ -123,6 +150,22 @@ def add_ddl(self, ddl: str, **kwargs) -> str: ) return self._format_point_id(id, self.ddl_collection_name) + async def aadd_ddl(self, ddl: str, **kwargs) -> str: + id = deterministic_uuid(ddl) + await self._async_client.upsert( + self.ddl_collection_name, + points=[ + models.PointStruct( + id=id, + vector=await self.agenerate_embedding(ddl), + payload={ + "ddl": ddl, + }, + ) + ], + ) + return self._format_point_id(id, self.ddl_collection_name) + def add_documentation(self, documentation: str, **kwargs) -> str: id = deterministic_uuid(documentation) @@ -141,6 +184,24 @@ def add_documentation(self, documentation: str, **kwargs) -> str: return self._format_point_id(id, self.documentation_collection_name) + async def aadd_documentation(self, documentation: str, **kwargs) -> str: + id = deterministic_uuid(documentation) + + await self._async_client.upsert( + self.documentation_collection_name, + points=[ + models.PointStruct( + id=id, + vector=await self.agenerate_embedding(documentation), + payload={ + "documentation": documentation, + }, + ) + ], + ) + + return self._format_point_id(id, self.documentation_collection_name) + def get_training_data(self, **kwargs) -> pd.DataFrame: df = pd.DataFrame() @@ -204,6 +265,69 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: return df + async def aget_training_data(self, **kwargs) -> pd.DataFrame: + df = pd.DataFrame() + + if sql_data := await self._aget_all_points(self.sql_collection_name): + question_list = [data.payload["question"] for data in sql_data] + sql_list = [data.payload["sql"] for data in sql_data] + id_list = [ + self._format_point_id(data.id, self.sql_collection_name) + for data in sql_data + ] + + df_sql = pd.DataFrame( + { + "id": id_list, + "question": question_list, + "content": sql_list, + } + ) + + df_sql["training_data_type"] = "sql" + + df = pd.concat([df, df_sql]) + + if ddl_data := await self._aget_all_points(self.ddl_collection_name): + ddl_list = [data.payload["ddl"] for data in ddl_data] + id_list = [ + self._format_point_id(data.id, self.ddl_collection_name) + for data in ddl_data + ] + + df_ddl = pd.DataFrame( + { + "id": id_list, + "question": [None for _ in ddl_list], + "content": ddl_list, + } + ) + + df_ddl["training_data_type"] = "ddl" + + df = pd.concat([df, df_ddl]) + + if doc_data := await self._aget_all_points(self.documentation_collection_name): + document_list = [data.payload["documentation"] for data in doc_data] + id_list = [ + self._format_point_id(data.id, self.documentation_collection_name) + for data in doc_data + ] + + df_doc = pd.DataFrame( + { + "id": id_list, + "question": [None for _ in document_list], + "content": document_list, + } + ) + + df_doc["training_data_type"] = "documentation" + + df = pd.concat([df, df_doc]) + + return df + def remove_training_data(self, id: str, **kwargs) -> bool: try: id, collection_name = self._parse_point_id(id) @@ -212,6 +336,14 @@ def remove_training_data(self, id: str, **kwargs) -> bool: except ValueError: return False + async def aremove_training_data(self, id: str, **kwargs) -> bool: + try: + id, collection_name = self._parse_point_id(id) + res = await self._async_client.delete(collection_name, points_selector=[id]) + return True + except ValueError: + return False + def remove_collection(self, collection_name: str) -> bool: """ This function can reset the collection to empty state. @@ -229,6 +361,23 @@ def remove_collection(self, collection_name: str) -> bool: else: return False + async def aremove_collection(self, collection_name: str) -> bool: + """ + This function can reset the collection to empty state. + + Args: + collection_name (str): sql or ddl or documentation + + Returns: + bool: True if collection is deleted, False otherwise + """ + if collection_name in self.id_suffixes.keys(): + await self._async_client.delete_collection(collection_name) + await self._asetup_collections() + return True + else: + return False + @cached_property def embeddings_dimension(self): return len(self.generate_embedding("ABCDEF")) @@ -243,6 +392,16 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list: return [dict(result.payload) for result in results] + async def aget_similar_question_sql(self, question: str, **kwargs) -> list: + results = await self._async_client.search( + self.sql_collection_name, + query_vector=await self.agenerate_embedding(question), + limit=self.n_results, + with_payload=True, + ) + + return [dict(result.payload) for result in results] + def get_related_ddl(self, question: str, **kwargs) -> list: results = self._client.search( self.ddl_collection_name, @@ -253,6 +412,16 @@ def get_related_ddl(self, question: str, **kwargs) -> list: return [result.payload["ddl"] for result in results] + async def aget_related_ddl(self, question: str, **kwargs) -> list: + results = await self._async_client.search( + self.ddl_collection_name, + query_vector=await self.agenerate_embedding(question), + limit=self.n_results, + with_payload=True, + ) + + return [result.payload["ddl"] for result in results] + def get_related_documentation(self, question: str, **kwargs) -> list: results = self._client.search( self.documentation_collection_name, @@ -263,6 +432,16 @@ def get_related_documentation(self, question: str, **kwargs) -> list: return [result.payload["documentation"] for result in results] + async def aget_related_documentation(self, question: str, **kwargs) -> list: + results = await self._async_client.search( + self.documentation_collection_name, + query_vector=await self.agenerate_embedding(question), + limit=self.n_results, + with_payload=True, + ) + + return [result.payload["documentation"] for result in results] + def generate_embedding(self, data: str, **kwargs) -> List[float]: embedding_model = self._client._get_or_init_model( model_name=self.fastembed_model @@ -271,6 +450,14 @@ def generate_embedding(self, data: str, **kwargs) -> List[float]: return embedding.tolist() + async def agenerate_embedding(self, data: str, **kwargs) -> List[float]: + embedding_model = self._async_client._get_or_init_model( + model_name=self.fastembed_model + ) + embedding = next(embedding_model.embed(data)) + + return embedding.tolist() + def _get_all_points(self, collection_name: str): results: List[models.Record] = [] next_offset = None @@ -293,6 +480,28 @@ def _get_all_points(self, collection_name: str): return results + async def _aget_all_points(self, collection_name: str): + results: List[models.Record] = [] + next_offset = None + stop_scrolling = False + while not stop_scrolling: + records, next_offset = await self._async_client.scroll( + collection_name, + limit=SCROLL_SIZE, + offset=next_offset, + with_payload=True, + with_vectors=False, + ) + stop_scrolling = next_offset is None or ( + isinstance(next_offset, grpc.PointId) + and next_offset.num == 0 + and next_offset.uuid == "" + ) + + results.extend(records) + + return results + def _setup_collections(self): if not self._client.collection_exists(self.sql_collection_name): self._client.create_collection( @@ -323,6 +532,38 @@ def _setup_collections(self): **self.collection_params, ) + async def _asetup_collections(self): + if not await self._async_client.collection_exists(self.sql_collection_name): + await self._async_client.create_collection( + collection_name=self.sql_collection_name, + vectors_config=models.VectorParams( + size=self.embeddings_dimension, + distance=self.distance_metric, + ), + **self.collection_params, + ) + + if not await self._async_client.collection_exists(self.ddl_collection_name): + await self._async_client.create_collection( + collection_name=self.ddl_collection_name, + vectors_config=models.VectorParams( + size=self.embeddings_dimension, + distance=self.distance_metric, + ), + **self.collection_params, + ) + if not await self._async_client.collection_exists( + self.documentation_collection_name + ): + await self._async_client.create_collection( + collection_name=self.documentation_collection_name, + vectors_config=models.VectorParams( + size=self.embeddings_dimension, + distance=self.distance_metric, + ), + **self.collection_params, + ) + def _format_point_id(self, id: str, collection_name: str) -> str: return "{0}-{1}".format(id, self.id_suffixes[collection_name]) From a9e82cc6399ea89583255662e05f5f3085b464ba Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Mon, 8 Jul 2024 21:07:08 +0800 Subject: [PATCH 02/30] fix: default config to {} --- src/vanna/openrouter/openrouter_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vanna/openrouter/openrouter_chat.py b/src/vanna/openrouter/openrouter_chat.py index 8ed4eab6..f12fa4df 100644 --- a/src/vanna/openrouter/openrouter_chat.py +++ b/src/vanna/openrouter/openrouter_chat.py @@ -12,7 +12,7 @@ def __init__( self, client=None, aclient=None, - config: dict[str, Any] | None = None, + config: dict[str, Any] = {}, ): VannaBase.__init__(self, config=config) # default parameters - can be overrided using config From cc3c858dca0d79aa87c3be8afbb297f71b3015bc Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Mon, 8 Jul 2024 21:13:28 +0800 Subject: [PATCH 03/30] fix: baseurl on openrouter --- src/vanna/openrouter/openrouter_chat.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/vanna/openrouter/openrouter_chat.py b/src/vanna/openrouter/openrouter_chat.py index f12fa4df..2c4ca3e4 100644 --- a/src/vanna/openrouter/openrouter_chat.py +++ b/src/vanna/openrouter/openrouter_chat.py @@ -45,7 +45,10 @@ def __init__( return if config is None and client is None: - self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + self.client = OpenAI( + api_key=os.getenv("OPENAI_API_KEY"), + base_url="https://openrouter.ai/api/v1", + ) return if aclient is not None: @@ -53,12 +56,19 @@ def __init__( return if config is None and aclient is None: - self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) + self.aclient = AsyncOpenAI( + api_key=os.getenv("OPENAI_API_KEY"), + base_url="https://openrouter.ai/api/v1", + ) return if "api_key" in config: - self.client = OpenAI(api_key=config["api_key"]) - self.aclient = AsyncOpenAI(api_key=config["api_key"]) + self.client = OpenAI( + api_key=config["api_key"], base_url="https://openrouter.ai/api/v1" + ) + self.aclient = AsyncOpenAI( + api_key=config["api_key"], base_url="https://openrouter.ai/api/v1" + ) def system_message(self, message: str) -> Any: return {"role": "system", "content": message} From bb261ce3219d9dc63681a1dc9a33fabf53015797 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Mon, 8 Jul 2024 23:41:20 +0800 Subject: [PATCH 04/30] fix: increase max token for openrouter --- src/vanna/openai/openai_embeddings.py | 8 ++++++-- src/vanna/openrouter/openrouter_chat.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/vanna/openai/openai_embeddings.py b/src/vanna/openai/openai_embeddings.py index 5e5a41a6..e0420aa6 100644 --- a/src/vanna/openai/openai_embeddings.py +++ b/src/vanna/openai/openai_embeddings.py @@ -1,16 +1,20 @@ -from openai import OpenAI +from openai import AsyncOpenAI, OpenAI from ..base import VannaBase class OpenAI_Embeddings(VannaBase): - def __init__(self, client=None, config=None): + def __init__(self, client=None, async_client=None, config=None): VannaBase.__init__(self, config=config) if client is not None: self.client = client return + if async_client is not None: + self.client = async_client + return + if self.client is not None: return diff --git a/src/vanna/openrouter/openrouter_chat.py b/src/vanna/openrouter/openrouter_chat.py index 2c4ca3e4..ce03463d 100644 --- a/src/vanna/openrouter/openrouter_chat.py +++ b/src/vanna/openrouter/openrouter_chat.py @@ -16,8 +16,8 @@ def __init__( ): VannaBase.__init__(self, config=config) # default parameters - can be overrided using config - self.temperature = 0.7 - self.max_tokens = 500 + self.temperature = 0.3 + self.max_tokens = 4000 if "temperature" in config: self.temperature = config["temperature"] From 6b6a4e5792a56e5d0bf845d10864a8056c0f4924 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Mon, 8 Jul 2024 23:55:43 +0800 Subject: [PATCH 05/30] fix: embeddings dimension on qdrant --- src/vanna/qdrant/qdrant.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/vanna/qdrant/qdrant.py b/src/vanna/qdrant/qdrant.py index 0ea1811e..742a1789 100644 --- a/src/vanna/qdrant/qdrant.py +++ b/src/vanna/qdrant/qdrant.py @@ -380,7 +380,12 @@ async def aremove_collection(self, collection_name: str) -> bool: @cached_property def embeddings_dimension(self): - return len(self.generate_embedding("ABCDEF")) + if self._client: + return len(self.generate_embedding("ABCDEF")) + elif self._async_client: + import asyncio + return len(asyncio.run(self.agenerate_embedding("ABCDEF"))) + return 0 def get_similar_question_sql(self, question: str, **kwargs) -> list: results = self._client.search( From 66b986156b3812be8f33bf2376fd76506dca94b8 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Mon, 8 Jul 2024 23:59:32 +0800 Subject: [PATCH 06/30] feat: embedding dimension from async --- src/vanna/qdrant/qdrant.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/vanna/qdrant/qdrant.py b/src/vanna/qdrant/qdrant.py index 742a1789..6d45c40c 100644 --- a/src/vanna/qdrant/qdrant.py +++ b/src/vanna/qdrant/qdrant.py @@ -380,12 +380,8 @@ async def aremove_collection(self, collection_name: str) -> bool: @cached_property def embeddings_dimension(self): - if self._client: - return len(self.generate_embedding("ABCDEF")) - elif self._async_client: - import asyncio - return len(asyncio.run(self.agenerate_embedding("ABCDEF"))) - return 0 + import asyncio + return len(asyncio.run(self.agenerate_embedding("ABCDEF"))) def get_similar_question_sql(self, question: str, **kwargs) -> list: results = self._client.search( From 1c65f7d51d7bfe60c7f77e1fc0cf2f155607ff27 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Tue, 9 Jul 2024 20:57:52 +0800 Subject: [PATCH 07/30] fix: better async session for sqlalchemy --- src/vanna/base/base.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index fa7822a5..b48fb1c8 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -55,7 +55,7 @@ import sqlite3 import traceback from abc import ABC, abstractmethod -from typing import Any, List, Tuple, Union +from typing import Any, AsyncGenerator, List, Tuple, Union from urllib.parse import urlparse import pandas as pd @@ -1374,13 +1374,14 @@ async def aconnect_to_mysql( if not port: port = int(os.getenv("PORT", 3306)) - conn = None + from sqlalchemy.ext.asyncio import async_sessionmaker try: engine = create_async_engine( url=f"mysql+aiomysql://{user}:{password}@{host}:{port}/{dbname}" ) - conn = await engine.connect() + async_session = async_sessionmaker(engine, expire_on_commit=False) + except Exception as e: raise ValidationError(e) @@ -1388,16 +1389,16 @@ async def arun_sql_mysql(sql: str, **kwargs) -> pd.DataFrame: from sqlalchemy import text try: - cs = await conn.execute(text(sql)) - results = cs.fetchall() + async with async_session() as session: + cs = await session.execute(text(sql)) + results = cs.fetchall() - columns = cs.keys() - # Create a pandas dataframe from the results - df = pd.DataFrame(results, columns=columns) # type: ignore - return df + columns = cs.keys() + # Create a pandas dataframe from the results + df = pd.DataFrame(results, columns=columns) # type: ignore + return df except Exception as e: - await conn.rollback() raise e self.arun_sql_is_set = True From 64e506f9b1e76ec86b61aa78dedb7212c1968da0 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Tue, 9 Jul 2024 21:34:23 +0800 Subject: [PATCH 08/30] feat: set max row to be fed to llm on summary to only 10 rows --- src/vanna/base/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index b48fb1c8..e0edee8c 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -428,6 +428,8 @@ def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str: Returns: str: The summary of the results of the SQL query. """ + # trim the dataframe to 10 rows + df = df.head(10) message_log = [ self.system_message( @@ -459,6 +461,8 @@ async def agenerate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> Returns: str: The summary of the results of the SQL query. """ + # trim the dataframe to 10 rows + df = df.head(10) message_log = [ self.system_message( From 3277bc1d37c96c20f0cbee4f49db1074525c1325 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Tue, 9 Jul 2024 21:57:16 +0800 Subject: [PATCH 09/30] feat: added retry on openrouter asubmit prompt --- src/vanna/openrouter/openrouter_chat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/vanna/openrouter/openrouter_chat.py b/src/vanna/openrouter/openrouter_chat.py index ce03463d..265647d5 100644 --- a/src/vanna/openrouter/openrouter_chat.py +++ b/src/vanna/openrouter/openrouter_chat.py @@ -3,6 +3,7 @@ from openai import AsyncOpenAI, OpenAI from openai.types.chat.chat_completion import Choice +from tenacity import retry from ..base import VannaBase @@ -119,6 +120,7 @@ def submit_prompt( return "" + @retry async def asubmit_prompt( self, prompt, model: str = "deepseek/deepseek-chat", **kwargs ) -> str: From fed4acf951c60ccbe52dfb6fa54727a9bb50445f Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Tue, 9 Jul 2024 22:29:50 +0800 Subject: [PATCH 10/30] chore: rename parsed_df --- src/vanna/base/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index e0edee8c..6d0965ea 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -429,11 +429,11 @@ def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str: str: The summary of the results of the SQL query. """ # trim the dataframe to 10 rows - df = df.head(10) + parsed_df = df.head(10) message_log = [ self.system_message( - f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n" + f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{parsed_df.to_markdown()}\n\n" ), self.user_message( "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary." @@ -462,11 +462,11 @@ async def agenerate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str: The summary of the results of the SQL query. """ # trim the dataframe to 10 rows - df = df.head(10) + parsed_df = df.head(10) message_log = [ self.system_message( - f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n" + f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{parsed_df.to_markdown()}\n\n" ), self.user_message( "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary." From 096a1ed7639f9adfc31ffb18f1e95d7488c5e522 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Tue, 9 Jul 2024 22:40:41 +0800 Subject: [PATCH 11/30] feat: increased pool size on mysql --- src/vanna/base/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 6d0965ea..c31a6780 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1382,7 +1382,11 @@ async def aconnect_to_mysql( try: engine = create_async_engine( - url=f"mysql+aiomysql://{user}:{password}@{host}:{port}/{dbname}" + url=f"mysql+aiomysql://{user}:{password}@{host}:{port}/{dbname}", + echo=True, + pool_size=20, # Increase pool size + max_overflow=30, # Increase max overflow + pool_recycle=3600, # Recycle connections after 1 hour ) async_session = async_sessionmaker(engine, expire_on_commit=False) From a9f9fd0a6e5b029d2ddef95bcfde168547a52b05 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Tue, 9 Jul 2024 22:48:41 +0800 Subject: [PATCH 12/30] feat: session mysql fix on concurrent --- src/vanna/base/base.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index c31a6780..8744a280 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1382,13 +1382,9 @@ async def aconnect_to_mysql( try: engine = create_async_engine( - url=f"mysql+aiomysql://{user}:{password}@{host}:{port}/{dbname}", - echo=True, - pool_size=20, # Increase pool size - max_overflow=30, # Increase max overflow - pool_recycle=3600, # Recycle connections after 1 hour + url=f"mysql+aiomysql://{user}:{password}@{host}:{port}/{dbname}" ) - async_session = async_sessionmaker(engine, expire_on_commit=False) + async_session = async_sessionmaker(engine) except Exception as e: raise ValidationError(e) @@ -1397,8 +1393,9 @@ async def arun_sql_mysql(sql: str, **kwargs) -> pd.DataFrame: from sqlalchemy import text try: - async with async_session() as session: + async with async_session.begin() as session: cs = await session.execute(text(sql)) + await session.commit() results = cs.fetchall() columns = cs.keys() From e03795a26ad9ac3b07846e0786fe2458189b995f Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Tue, 9 Jul 2024 22:59:47 +0800 Subject: [PATCH 13/30] feat: session concurrent user fix --- src/vanna/base/base.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 8744a280..4f9ce115 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1392,19 +1392,23 @@ async def aconnect_to_mysql( async def arun_sql_mysql(sql: str, **kwargs) -> pd.DataFrame: from sqlalchemy import text + session = async_session() + try: - async with async_session.begin() as session: - cs = await session.execute(text(sql)) - await session.commit() - results = cs.fetchall() + cs = await session.execute(text(sql)) + results = cs.fetchall() - columns = cs.keys() - # Create a pandas dataframe from the results - df = pd.DataFrame(results, columns=columns) # type: ignore - return df + columns = cs.keys() + # Create a pandas dataframe from the results + df = pd.DataFrame(results, columns=columns) # type: ignore + await session.commit() + return df except Exception as e: + await session.rollback() raise e + finally: + await session.close() self.arun_sql_is_set = True self.arun_sql = arun_sql_mysql From 2efe9edcb318a1449063c584f12fb42ef7fc1577 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Wed, 10 Jul 2024 14:56:52 +0800 Subject: [PATCH 14/30] feat: expose api key --- src/vanna/openrouter/openrouter_chat.py | 57 ++++++------------------- 1 file changed, 12 insertions(+), 45 deletions(-) diff --git a/src/vanna/openrouter/openrouter_chat.py b/src/vanna/openrouter/openrouter_chat.py index 265647d5..5cb01708 100644 --- a/src/vanna/openrouter/openrouter_chat.py +++ b/src/vanna/openrouter/openrouter_chat.py @@ -11,14 +11,22 @@ class OpenRouter_Chat(VannaBase): def __init__( self, - client=None, - aclient=None, + client=OpenAI( + api_key=os.getenv("OPENAI_API_KEY"), + base_url=os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"), + ), + aclient=AsyncOpenAI( + api_key=os.getenv("OPENAI_API_KEY"), + base_url=os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"), + ), config: dict[str, Any] = {}, ): VannaBase.__init__(self, config=config) # default parameters - can be overrided using config self.temperature = 0.3 self.max_tokens = 4000 + self.client = client + self.aclient = aclient if "temperature" in config: self.temperature = config["temperature"] @@ -26,50 +34,9 @@ def __init__( if "max_tokens" in config: self.max_tokens = config["max_tokens"] - if "api_type" in config: - raise Exception( - "Passing api_type is now deprecated. Please pass an OpenAI client instead." - ) - - if "api_base" in config: - raise Exception( - "Passing api_base is now deprecated. Please pass an OpenAI client instead." - ) - - if "api_version" in config: - raise Exception( - "Passing api_version is now deprecated. Please pass an OpenAI client instead." - ) - - if client is not None: - self.client = client - return - - if config is None and client is None: - self.client = OpenAI( - api_key=os.getenv("OPENAI_API_KEY"), - base_url="https://openrouter.ai/api/v1", - ) - return - - if aclient is not None: - self.aclient = aclient - return - - if config is None and aclient is None: - self.aclient = AsyncOpenAI( - api_key=os.getenv("OPENAI_API_KEY"), - base_url="https://openrouter.ai/api/v1", - ) - return - if "api_key" in config: - self.client = OpenAI( - api_key=config["api_key"], base_url="https://openrouter.ai/api/v1" - ) - self.aclient = AsyncOpenAI( - api_key=config["api_key"], base_url="https://openrouter.ai/api/v1" - ) + self.client.api_key = config["api_key"] + self.aclient.api_key = config["api_key"] def system_message(self, message: str) -> Any: return {"role": "system", "content": message} From 92a44d47867617f8dea67dd01c503e755798a612 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Wed, 10 Jul 2024 15:05:06 +0800 Subject: [PATCH 15/30] feat: expose model through param --- src/vanna/openrouter/openrouter_chat.py | 35 +++++++++---------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/src/vanna/openrouter/openrouter_chat.py b/src/vanna/openrouter/openrouter_chat.py index 5cb01708..0a7857e1 100644 --- a/src/vanna/openrouter/openrouter_chat.py +++ b/src/vanna/openrouter/openrouter_chat.py @@ -23,20 +23,15 @@ def __init__( ): VannaBase.__init__(self, config=config) # default parameters - can be overrided using config - self.temperature = 0.3 - self.max_tokens = 4000 + self.temperature = config.get("temperature", 0.7) + self.max_tokens = config.get("max_tokens", 500) self.client = client self.aclient = aclient - - if "temperature" in config: - self.temperature = config["temperature"] - - if "max_tokens" in config: - self.max_tokens = config["max_tokens"] - - if "api_key" in config: - self.client.api_key = config["api_key"] - self.aclient.api_key = config["api_key"] + self.model = config.get( + "model", os.getenv("OPENROUTER_MODEL", "deepseek/deepseek-chat") + ) + self.client.api_key = config.get("api_key", os.getenv("OPENROUTER_API_KEY")) + self.aclient.api_key = config.get("api_key", os.getenv("OPENROUTER_API_KEY")) def system_message(self, message: str) -> Any: return {"role": "system", "content": message} @@ -88,9 +83,7 @@ def submit_prompt( return "" @retry - async def asubmit_prompt( - self, prompt, model: str = "deepseek/deepseek-chat", **kwargs - ) -> str: + async def asubmit_prompt(self, prompt, **kwargs) -> str: if prompt is None: raise Exception("Prompt is None") @@ -103,9 +96,9 @@ async def asubmit_prompt( for message in prompt: num_tokens += len(message["content"]) / 4 - print(f"Using model {model} for {num_tokens} tokens (approx)") + print(f"Using model {self.model} for {num_tokens} tokens (approx)") response = await self.aclient.chat.completions.create( - model=model, + model=self.model, messages=prompt, max_tokens=self.max_tokens, stop=None, @@ -128,9 +121,7 @@ async def asubmit_prompt( return "" - async def astream_submit_prompt( - self, prompt, model: str = "deepseek/deepseek-chat", **kwargs - ) -> AsyncIterable[str]: + async def astream_submit_prompt(self, prompt, **kwargs) -> AsyncIterable[str]: if prompt is None: raise Exception("Prompt is None") @@ -143,9 +134,9 @@ async def astream_submit_prompt( for message in prompt: num_tokens += len(message["content"]) / 4 - print(f"Using model {model} for {num_tokens} tokens (approx)") + print(f"Using model {self.model} for {num_tokens} tokens (approx)") stream = await self.aclient.chat.completions.create( - model=model, + model=self.model, messages=prompt, max_tokens=self.max_tokens, stop=None, From 9e10d91b6dd36680df7cb91c0110bbd47186e29c Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Wed, 10 Jul 2024 21:13:24 +0800 Subject: [PATCH 16/30] feat: return asyncengine upon mysql connect to cleanup --- src/vanna/base/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 4f9ce115..a185d7f1 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1413,6 +1413,8 @@ async def arun_sql_mysql(sql: str, **kwargs) -> pd.DataFrame: self.arun_sql_is_set = True self.arun_sql = arun_sql_mysql + return engine + def connect_to_clickhouse( self, host: str = None, From 267f1281fac2a872d42b779d38cbed937ff79a73 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Thu, 11 Jul 2024 01:37:43 +0800 Subject: [PATCH 17/30] feat: overflow 0 --- src/vanna/base/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index a185d7f1..fe4533d0 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1382,7 +1382,8 @@ async def aconnect_to_mysql( try: engine = create_async_engine( - url=f"mysql+aiomysql://{user}:{password}@{host}:{port}/{dbname}" + url=f"mysql+aiomysql://{user}:{password}@{host}:{port}/{dbname}", + max_overflow=0 ) async_session = async_sessionmaker(engine) From 86b8691148b18b1127fd9dbaace0dd00e29fedbd Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Thu, 11 Jul 2024 01:57:45 +0800 Subject: [PATCH 18/30] feat: increase max overflow and timeout --- src/vanna/base/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index fe4533d0..e113ccdc 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1383,7 +1383,8 @@ async def aconnect_to_mysql( try: engine = create_async_engine( url=f"mysql+aiomysql://{user}:{password}@{host}:{port}/{dbname}", - max_overflow=0 + max_overflow=-1, + timeout=3600, ) async_session = async_sessionmaker(engine) From e5e5d7713cf3d2b2acc6b612a92d310c5c3c24c7 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Thu, 11 Jul 2024 02:35:00 +0800 Subject: [PATCH 19/30] feat: add pool timeout --- src/vanna/base/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index e113ccdc..36f1ff99 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1383,8 +1383,7 @@ async def aconnect_to_mysql( try: engine = create_async_engine( url=f"mysql+aiomysql://{user}:{password}@{host}:{port}/{dbname}", - max_overflow=-1, - timeout=3600, + pool_timeout=3600, ) async_session = async_sessionmaker(engine) From 2e2171abc63000bbcd717ab619ed455485970564 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Thu, 11 Jul 2024 13:38:19 +0800 Subject: [PATCH 20/30] feat: improve agenerate summary prompt --- src/vanna/base/base.py | 52 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 36f1ff99..19db5a68 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -445,7 +445,9 @@ def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str: return summary - async def agenerate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str: + async def agenerate_summary( + self, question: str, df: pd.DataFrame, sql: str, **kwargs + ) -> str: """ **Example:** ```python @@ -466,10 +468,54 @@ async def agenerate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> message_log = [ self.system_message( - f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{parsed_df.to_markdown()}\n\n" + f"""Let's think step by step. + You are a helpful data assistant. + Given a set of question, sql query, and sql results, generate an answer to the question.""" + ), + self.assistant_message( + """Certainly! To generate an answer to a question given a SQL query and its results, we can follow these steps: +1. **Understand the Question**: First, we need to clearly understand what the question is asking. This will help us determine what information from the SQL results is relevant. +2. **Analyze the SQL Query**: Review the SQL query to understand what data it is retrieving and how it is structured. This will help us map the results back to the question. +3. **Review the SQL Results**: Look at the results returned by the SQL query. Identify the columns and rows that contain the data relevant to the question. +4. **Generate the Answer**: Using the relevant data from the SQL results, construct an answer to the question. Ensure that the answer is clear, concise, and directly addresses the question. + +Let's go through an example to illustrate this process: + +### Example + +**Question**: How many orders were placed in the month of January 2023? + +**SQL Query**: +```sql +SELECT COUNT(*) AS order_count +FROM orders +WHERE order_date BETWEEN '2023-01-01' AND '2023-01-31'; +``` + +**SQL Results**: +``` +order_count +----------- +500 +``` + +### Step-by-Step Process + +1. **Understand the Question**: The question is asking for the number of orders placed in January 2023. +2. **Analyze the SQL Query**: The SQL query counts the number of orders where the `order_date` falls within January 2023. +3. **Review the SQL Results**: The results show that there were 500 orders. +4. **Generate the Answer**: The answer to the question is: "There were 500 orders placed in the month of January 2023." +By following these steps, we can ensure that the answer is accurate and directly addresses the question based on the provided SQL query and results.""" ), self.user_message( - "Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary." + f"""**Question**: + {question} + + **SQL query**: + {sql} + + **SQL results (limited to 10 rows)**: + {parsed_df.to_markdown()}""" + self._response_language() ), ] From 09be181e66acc90888aced20ff1da1cf8cc43a52 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Thu, 11 Jul 2024 13:58:50 +0800 Subject: [PATCH 21/30] feat: improved prompt on agenerate summary --- src/vanna/base/base.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 19db5a68..98477f59 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -469,8 +469,8 @@ async def agenerate_summary( message_log = [ self.system_message( f"""Let's think step by step. - You are a helpful data assistant. - Given a set of question, sql query, and sql results, generate an answer to the question.""" +You are a helpful data assistant. +Given a set of question, sql query, and sql results, generate an answer to the question.""" ), self.assistant_message( """Certainly! To generate an answer to a question given a SQL query and its results, we can follow these steps: @@ -509,14 +509,17 @@ async def agenerate_summary( ), self.user_message( f"""**Question**: - {question} +{question} - **SQL query**: - {sql} - - **SQL results (limited to 10 rows)**: - {parsed_df.to_markdown()}""" - + self._response_language() +**SQL query**: +``` +{sql} +``` +**SQL results (limited to 10 rows)**: +``` +{parsed_df.to_markdown()} +``` +""" ), ] From f00986c58935e163100ebea7d17ee2ad643d0b60 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Thu, 11 Jul 2024 14:44:36 +0800 Subject: [PATCH 22/30] feat: async with instead of finally --- src/vanna/base/base.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 98477f59..67291a8f 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1442,23 +1442,15 @@ async def aconnect_to_mysql( async def arun_sql_mysql(sql: str, **kwargs) -> pd.DataFrame: from sqlalchemy import text - session = async_session() - - try: - cs = await session.execute(text(sql)) - results = cs.fetchall() - - columns = cs.keys() - # Create a pandas dataframe from the results - df = pd.DataFrame(results, columns=columns) # type: ignore - await session.commit() - return df + async with async_session() as session: + async with session.begin(): + cs = await session.execute(text(sql)) + results = cs.fetchall() - except Exception as e: - await session.rollback() - raise e - finally: - await session.close() + columns = cs.keys() + df = pd.DataFrame(results, columns=columns) # type: ignore + await session.rollback() + return df self.arun_sql_is_set = True self.arun_sql = arun_sql_mysql From de953ea057f9d0d20742f9729009ee4022d8ab3c Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Fri, 12 Jul 2024 23:01:40 +0800 Subject: [PATCH 23/30] feat: expose client on a connect to mysql --- src/vanna/base/base.py | 66 ++++++++---------------------------------- 1 file changed, 12 insertions(+), 54 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 67291a8f..0e4f10c6 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -64,6 +64,7 @@ import plotly.graph_objects as go import requests import sqlparse +from sqlalchemy.ext.asyncio import create_async_engine from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError from ..types import TrainingPlan, TrainingPlanItem @@ -1385,72 +1386,29 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]: async def aconnect_to_mysql( self, - host: str | None = None, - dbname: str | None = None, - user: str | None = None, - password: str | None = None, - port: int | None = None, + engine=create_async_engine( + url=f"mysql+aiomysql://{os.getenv("DB_USER")}:{os.getenv('DB_PASSWORD')}@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}", + pool_timeout=3600, + ), ): - - try: - from sqlalchemy.ext.asyncio import create_async_engine - except ImportError: - raise DependencyError( - "You need to install required dependencies to execute this method," - " run command: \npip install aiomysql" - ) - - if not host: - host = os.getenv("HOST") - - if not host: - raise ImproperlyConfigured("Please set your MySQL host") - - if not dbname: - dbname = os.getenv("DATABASE") - - if not dbname: - raise ImproperlyConfigured("Please set your MySQL database") - - if not user: - user = os.getenv("USER") - - if not user: - raise ImproperlyConfigured("Please set your MySQL user") - - if not password: - password = os.getenv("PASSWORD") - - if not password: - raise ImproperlyConfigured("Please set your MySQL password") - - if not port: - port = int(os.getenv("PORT", 3306)) - from sqlalchemy.ext.asyncio import async_sessionmaker - try: - engine = create_async_engine( - url=f"mysql+aiomysql://{user}:{password}@{host}:{port}/{dbname}", - pool_timeout=3600, - ) - async_session = async_sessionmaker(engine) - - except Exception as e: - raise ValidationError(e) + async_session = async_sessionmaker(engine) async def arun_sql_mysql(sql: str, **kwargs) -> pd.DataFrame: from sqlalchemy import text - async with async_session() as session: - async with session.begin(): + async with async_session.begin() as session: + try: cs = await session.execute(text(sql)) results = cs.fetchall() - columns = cs.keys() df = pd.DataFrame(results, columns=columns) # type: ignore - await session.rollback() return df + except Exception as e: + raise ValidationError(e) + finally: + await session.rollback() self.arun_sql_is_set = True self.arun_sql = arun_sql_mysql From 690953116d9231b330a89fffbd1ea363768c4988 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Fri, 12 Jul 2024 23:09:30 +0800 Subject: [PATCH 24/30] fix: aconnectmysql --- src/vanna/base/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 0e4f10c6..91035d72 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1387,7 +1387,7 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]: async def aconnect_to_mysql( self, engine=create_async_engine( - url=f"mysql+aiomysql://{os.getenv("DB_USER")}:{os.getenv('DB_PASSWORD')}@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}", + url=f"""mysql+aiomysql://{os.getenv("DB_USER")}:{os.getenv('DB_PASSWORD')}@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}""", pool_timeout=3600, ), ): From 0365e7c0f08daa07e6bed7bf997617271227bd57 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Fri, 12 Jul 2024 23:12:39 +0800 Subject: [PATCH 25/30] fix: port --- src/vanna/base/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 91035d72..eeaad1f8 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1387,7 +1387,7 @@ def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]: async def aconnect_to_mysql( self, engine=create_async_engine( - url=f"""mysql+aiomysql://{os.getenv("DB_USER")}:{os.getenv('DB_PASSWORD')}@{os.getenv('DB_HOST')}:{os.getenv('DB_PORT')}/{os.getenv('DB_NAME')}""", + url=f"""mysql+aiomysql://{os.getenv("DB_USER","")}:{os.getenv('DB_PASSWORD',"")}@{os.getenv('DB_HOST',"")}:{int(os.getenv('DB_PORT',3306))}/{os.getenv('DB_NAME',"")}""", pool_timeout=3600, ), ): From 31188d6d0ecff3fbe4eb3e7c7722586e1e1d94f3 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Fri, 19 Jul 2024 18:40:08 +0800 Subject: [PATCH 26/30] feat: improve sql agent summary --- src/vanna/base/base.py | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index eeaad1f8..ff6b6fdc 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -471,20 +471,18 @@ async def agenerate_summary( self.system_message( f"""Let's think step by step. You are a helpful data assistant. -Given a set of question, sql query, and sql results, generate an answer to the question.""" - ), - self.assistant_message( - """Certainly! To generate an answer to a question given a SQL query and its results, we can follow these steps: -1. **Understand the Question**: First, we need to clearly understand what the question is asking. This will help us determine what information from the SQL results is relevant. -2. **Analyze the SQL Query**: Review the SQL query to understand what data it is retrieving and how it is structured. This will help us map the results back to the question. -3. **Review the SQL Results**: Look at the results returned by the SQL query. Identify the columns and rows that contain the data relevant to the question. -4. **Generate the Answer**: Using the relevant data from the SQL results, construct an answer to the question. Ensure that the answer is clear, concise, and directly addresses the question. - -Let's go through an example to illustrate this process: +Given a set of question, sql query, and sql results, generate an answer to the question. -### Example +### Step-by-Step Process -**Question**: How many orders were placed in the month of January 2023? +1. **Understand the Question**: The question is asking for the number of orders placed in January 2023. +2. **Analyze the SQL Query**: The SQL query counts the number of orders where the `order_date` falls within January 2023. +3. **Review the SQL Results**: The results show that there were 500 orders. +4. **Generate the Answer**: The answer to the question is: "There were 500 orders placed in the month of January 2023." +By following these steps, we can ensure that the answer is accurate and directly addresses the question based on the provided SQL query and results.""" + ), + self.user_message( + """**Question**: How many orders were placed in the month of January 2023? **SQL Query**: ```sql @@ -493,20 +491,16 @@ async def agenerate_summary( WHERE order_date BETWEEN '2023-01-01' AND '2023-01-31'; ``` -**SQL Results**: +**SQL Results (limited to 10 rows)**: ``` order_count ----------- 500 ``` - -### Step-by-Step Process - -1. **Understand the Question**: The question is asking for the number of orders placed in January 2023. -2. **Analyze the SQL Query**: The SQL query counts the number of orders where the `order_date` falls within January 2023. -3. **Review the SQL Results**: The results show that there were 500 orders. -4. **Generate the Answer**: The answer to the question is: "There were 500 orders placed in the month of January 2023." -By following these steps, we can ensure that the answer is accurate and directly addresses the question based on the provided SQL query and results.""" +""" + ), + self.assistant_message( + """There were 500 orders placed in the month of January 2023.""" ), self.user_message( f"""**Question**: From 180233b099bb53ab598b278d9b8014383ad089d1 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Fri, 19 Jul 2024 21:54:37 +0800 Subject: [PATCH 27/30] fix: agenerate summary prompt --- src/vanna/base/base.py | 40 ++++++++++------------------------------ 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index ff6b6fdc..53ca4447 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -471,7 +471,7 @@ async def agenerate_summary( self.system_message( f"""Let's think step by step. You are a helpful data assistant. -Given a set of question, sql query, and sql results, generate an answer to the question. +Given a set of question, sql query, and sql results, generate an answer to the original question. ### Step-by-Step Process @@ -479,42 +479,22 @@ async def agenerate_summary( 2. **Analyze the SQL Query**: The SQL query counts the number of orders where the `order_date` falls within January 2023. 3. **Review the SQL Results**: The results show that there were 500 orders. 4. **Generate the Answer**: The answer to the question is: "There were 500 orders placed in the month of January 2023." -By following these steps, we can ensure that the answer is accurate and directly addresses the question based on the provided SQL query and results.""" - ), - self.user_message( - """**Question**: How many orders were placed in the month of January 2023? +5. Your answer should be as detailed as possible, copy the sql result if necessary. + +By following these steps, we can ensure that the answer is accurate and directly addresses the question based on the provided SQL query and results. + +Example: +**Question**: {question} **SQL Query**: ```sql -SELECT COUNT(*) AS order_count -FROM orders -WHERE order_date BETWEEN '2023-01-01' AND '2023-01-31'; +{sql} ``` **SQL Results (limited to 10 rows)**: -``` -order_count ------------ -500 -``` -""" - ), - self.assistant_message( - """There were 500 orders placed in the month of January 2023.""" - ), - self.user_message( - f"""**Question**: -{question} - -**SQL query**: -``` -{sql} -``` -**SQL results (limited to 10 rows)**: -``` +```md {parsed_df.to_markdown()} -``` -""" +```""" ), ] From e12df614e38580c9db3b675f292a089abfe11100 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Wed, 7 Aug 2024 22:55:19 +0800 Subject: [PATCH 28/30] perf: faster ageneratesql --- src/vanna/base/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 53ca4447..c09f2143 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -178,9 +178,11 @@ async def agenerate_sql( initial_prompt = self.config.get("initial_prompt", None) else: initial_prompt = None - question_sql_list = await self.aget_similar_question_sql(question, **kwargs) - ddl_list = await self.aget_related_ddl(question, **kwargs) - doc_list = await self.aget_related_documentation(question, **kwargs) + coros = [] + coros.append(self.aget_similar_question_sql(question, **kwargs)) + coros.append(self.aget_related_ddl(question, **kwargs)) + coros.append(self.aget_related_documentation(question, **kwargs)) + question_sql_list, ddl_list, doc_list = await asyncio.gather(*coros) prompt = self.get_sql_prompt( initial_prompt=initial_prompt or "", question=question, From f6f7225785d1e4f1de78addd8cd81d2ff8e24f60 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Wed, 7 Aug 2024 23:14:51 +0800 Subject: [PATCH 29/30] fix: prevent context overflow on intermediate sql --- src/vanna/base/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index c09f2143..a2c6d50d 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -213,8 +213,8 @@ async def agenerate_sql( ddl_list=ddl_list, doc_list=doc_list + [ - f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" - + df.to_markdown() + f"The following is a pandas DataFrame (limited to 10 rows) with the results of the intermediate SQL query {intermediate_sql}: \n" + + df.head(10).to_markdown() ], **kwargs, ) From 1728c2f3b9ea48fe964a189591db7d52bf1209d9 Mon Sep 17 00:00:00 2001 From: vikyw89 Date: Wed, 21 Aug 2024 15:37:10 +0800 Subject: [PATCH 30/30] perf: fix lost in the middle on long context --- src/vanna/base/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index a2c6d50d..f409dd35 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -183,6 +183,10 @@ async def agenerate_sql( coros.append(self.aget_related_ddl(question, **kwargs)) coros.append(self.aget_related_documentation(question, **kwargs)) question_sql_list, ddl_list, doc_list = await asyncio.gather(*coros) + + # reverse due to llm lost in the middle + # top result should get better attention (at the end) + question_sql_list.reverse() prompt = self.get_sql_prompt( initial_prompt=initial_prompt or "", question=question,