Skip to content

Commit

Permalink
fix: Cherry-pick upstream changes from deepset-ai/haystack#8572
Browse files Browse the repository at this point in the history
  • Loading branch information
shadeMe committed Nov 25, 2024
1 parent fee5178 commit ba09cec
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
22 changes: 14 additions & 8 deletions haystack_experimental/components/builders/chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Literal, Optional, Set, Union

from haystack import component, default_from_dict, default_to_dict, logging
from jinja2 import meta
Expand Down Expand Up @@ -104,7 +104,7 @@ class ChatPromptBuilder:
def __init__(
self,
template: Optional[List[ChatMessage]] = None,
required_variables: Optional[List[str]] = None,
required_variables: Optional[Union[List[str], Literal["*"]]] = None,
variables: Optional[List[str]] = None,
):
"""
Expand All @@ -116,7 +116,8 @@ def __init__(
the `init` method` or the `run` method.
:param required_variables:
List variables that must be provided as input to ChatPromptBuilder.
If a variable listed as required is not provided, an exception is raised. Optional.
If a variable listed as required is not provided, an exception is raised.
If set to "*", all variables found in the prompt are required.
:param variables:
List input variables to use in prompt templates instead of the ones inferred from the
`template` parameter. For example, to use more variables during prompt engineering than the ones present
Expand All @@ -131,7 +132,7 @@ def __init__(
if template and not variables:
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
# infere variables from template
# infer variables from template
if message.text is None:
raise ValueError(
f"The {self.__class__.__name__} requires a non-empty list of ChatMessage"
Expand All @@ -140,10 +141,11 @@ def __init__(
ast = self._env.parse(message.text)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
self.variables = variables

# setup inputs
for var in variables:
if var in self.required_variables:
for var in self.variables:
if self.required_variables == "*" or var in self.required_variables:
component.set_input_type(self, var, Any)
else:
component.set_input_type(self, var, Any, "")
Expand Down Expand Up @@ -224,12 +226,16 @@ def _validate_variables(self, provided_variables: Set[str]):
:raises ValueError:
If no template is provided or if all the required template variables are not provided.
"""
missing_variables = [var for var in self.required_variables if var not in provided_variables]
if self.required_variables == "*":
required_variables = sorted(self.variables)
else:
required_variables = self.required_variables
missing_variables = [var for var in required_variables if var not in provided_variables]
if missing_variables:
missing_vars_str = ", ".join(missing_variables)
raise ValueError(
f"Missing required input variables in ChatPromptBuilder: {missing_vars_str}. "
f"Required variables: {self.required_variables}. Provided variables: {provided_variables}."
f"Required variables: {required_variables}. Provided variables: {provided_variables}."
)

def to_dict(self) -> Dict[str, Any]:
Expand Down
12 changes: 12 additions & 0 deletions test/components/builders/test_chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,18 @@ def test_run_template_variable_overrides_variable(self):
"prompt": [ChatMessage.from_user("This is a test_from_template_var")]
}

def test_run_with_missing_required_input_using_star(self):
builder = ChatPromptBuilder(
template=[ChatMessage.from_user("This is a {{ foo }}, not a {{ bar }}")],
required_variables="*",
)
with pytest.raises(ValueError, match="foo"):
builder.run(bar="bar")
with pytest.raises(ValueError, match="bar"):
builder.run(foo="foo")
with pytest.raises(ValueError, match="bar, foo"):
builder.run()

def test_run_without_input(self):
builder = ChatPromptBuilder(
template=[ChatMessage.from_user("This is a template without input")]
Expand Down

0 comments on commit ba09cec

Please sign in to comment.