def add_ddl_to_prompt(
    self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
) -> str:
    """Append a "Tables" section containing as many DDL statements as fit.

    Each statement is accepted only if the approximate token count of the
    prompt built so far (including the section header and every statement
    already accepted) plus the statement's own count stays below
    ``max_tokens``. Statements that do not fit are skipped; later, smaller
    statements may still be accepted.

    Args:
        initial_prompt: Prompt text to extend.
        ddl_list: DDL statements to consider, in priority order.
        max_tokens: Exclusive upper bound used for the budget check.

    Returns:
        The prompt followed by the header and the accepted statements, or
        ``initial_prompt`` unchanged when no statement was accepted (so a
        dangling header is never emitted).
    """
    # Build on a candidate copy: the header is only committed to the result
    # once at least one statement has actually been accepted.
    candidate = initial_prompt + "\n===Tables \n"
    accepted_any = False

    for ddl in ddl_list:
        # Budget against the growing candidate, not the untouched
        # initial_prompt, so the limit applies to the whole prompt.
        if (
            self.str_to_approx_token_count(candidate)
            + self.str_to_approx_token_count(ddl)
            < max_tokens
        ):
            candidate += f"{ddl}\n\n"
            accepted_any = True

    return candidate if accepted_any else initial_prompt


def add_documentation_to_prompt(
    self,
    initial_prompt: str,
    documentation_list: list[str],
    max_tokens: int = 14000,
) -> str:
    """Append an "Additional Context" section with as much documentation as fits.

    Each entry is accepted only if the approximate token count of the prompt
    built so far (header and previously accepted entries included) plus the
    entry's own count stays below ``max_tokens``. Entries that do not fit are
    skipped; later entries are still considered.

    Args:
        initial_prompt: Prompt text to extend.
        documentation_list: Documentation snippets, in priority order.
        max_tokens: Exclusive upper bound used for the budget check.

    Returns:
        The prompt followed by the header and the accepted entries, or
        ``initial_prompt`` unchanged when no entry was accepted.
    """
    candidate = initial_prompt + "\n===Additional Context \n\n"
    accepted_any = False

    for documentation in documentation_list:
        # Cumulative budget: measure the candidate that already contains
        # every previously accepted entry.
        if (
            self.str_to_approx_token_count(candidate)
            + self.str_to_approx_token_count(documentation)
            < max_tokens
        ):
            candidate += f"{documentation}\n\n"
            accepted_any = True

    return candidate if accepted_any else initial_prompt


def add_sql_to_prompt(
    self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
) -> str:
    """Append a "Question-SQL Pairs" section with as many examples as fit.

    Despite the ``list[str]`` annotation, every element of ``sql_list`` is
    indexed with the keys ``"question"`` and ``"sql"``. Only the ``"sql"``
    text is measured for the budget check, while both the question and the
    SQL are appended.

    Args:
        initial_prompt: Prompt text to extend.
        sql_list: Mappings providing ``"question"`` and ``"sql"`` entries.
        max_tokens: Exclusive upper bound used for the budget check.

    Returns:
        The prompt followed by the header and the accepted pairs, or
        ``initial_prompt`` unchanged when no pair was accepted.
    """
    candidate = initial_prompt + "\n===Question-SQL Pairs\n\n"
    accepted_any = False

    for question in sql_list:
        # Cumulative budget: the candidate already holds all accepted pairs.
        if (
            self.str_to_approx_token_count(candidate)
            + self.str_to_approx_token_count(question["sql"])
            < max_tokens
        ):
            candidate += f"{question['question']}\n{question['sql']}\n\n"
            accepted_any = True

    return candidate if accepted_any else initial_prompt