From ba09cec467419f2968607162b5156d434b97868e Mon Sep 17 00:00:00 2001 From: shadeMe Date: Mon, 25 Nov 2024 16:51:20 +0100 Subject: [PATCH] fix: Cherry-pick upstream changes from https://github.com/deepset-ai/haystack/pull/8572 --- .../builders/chat_prompt_builder.py | 22 ++++++++++++------- .../builders/test_chat_prompt_builder.py | 12 ++++++++++ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/haystack_experimental/components/builders/chat_prompt_builder.py b/haystack_experimental/components/builders/chat_prompt_builder.py index d1c6f7b..38bc181 100644 --- a/haystack_experimental/components/builders/chat_prompt_builder.py +++ b/haystack_experimental/components/builders/chat_prompt_builder.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from copy import deepcopy -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Literal, Optional, Set, Union from haystack import component, default_from_dict, default_to_dict, logging from jinja2 import meta @@ -104,7 +104,7 @@ class ChatPromptBuilder: def __init__( self, template: Optional[List[ChatMessage]] = None, - required_variables: Optional[List[str]] = None, + required_variables: Optional[Union[List[str], Literal["*"]]] = None, variables: Optional[List[str]] = None, ): """ @@ -116,7 +116,8 @@ def __init__( the `init` method` or the `run` method. :param required_variables: List variables that must be provided as input to ChatPromptBuilder. - If a variable listed as required is not provided, an exception is raised. Optional. + If a variable listed as required is not provided, an exception is raised. + If set to "*", all variables found in the prompt are required. :param variables: List input variables to use in prompt templates instead of the ones inferred from the `template` parameter. For example, to use more variables during prompt engineering than the ones present @@ -131,7 +132,7 @@ def __init__( if template and not variables: for message in template: if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM): - # infere variables from template + # infer variables from template if message.text is None: raise ValueError( f"The {self.__class__.__name__} requires a non-empty list of ChatMessage" @@ -140,10 +141,11 @@ def __init__( ast = self._env.parse(message.text) template_variables = meta.find_undeclared_variables(ast) variables += list(template_variables) + self.variables = variables # setup inputs - for var in variables: - if var in self.required_variables: + for var in self.variables: + if self.required_variables == "*" or var in self.required_variables: component.set_input_type(self, var, Any) else: component.set_input_type(self, var, Any, "") @@ -224,12 +226,16 @@ def _validate_variables(self, provided_variables: Set[str]): :raises ValueError: If no template is provided or if all the required template variables are not provided. """ - missing_variables = [var for var in self.required_variables if var not in provided_variables] + if self.required_variables == "*": + required_variables = sorted(self.variables) + else: + required_variables = self.required_variables + missing_variables = [var for var in required_variables if var not in provided_variables] if missing_variables: missing_vars_str = ", ".join(missing_variables) raise ValueError( f"Missing required input variables in ChatPromptBuilder: {missing_vars_str}. " - f"Required variables: {self.required_variables}. Provided variables: {provided_variables}." + f"Required variables: {required_variables}. Provided variables: {provided_variables}." ) def to_dict(self) -> Dict[str, Any]: diff --git a/test/components/builders/test_chat_prompt_builder.py b/test/components/builders/test_chat_prompt_builder.py index 703b270..cf4791d 100644 --- a/test/components/builders/test_chat_prompt_builder.py +++ b/test/components/builders/test_chat_prompt_builder.py @@ -140,6 +140,18 @@ def test_run_template_variable_overrides_variable(self): "prompt": [ChatMessage.from_user("This is a test_from_template_var")] } + def test_run_with_missing_required_input_using_star(self): + builder = ChatPromptBuilder( + template=[ChatMessage.from_user("This is a {{ foo }}, not a {{ bar }}")], + required_variables="*", + ) + with pytest.raises(ValueError, match="foo"): + builder.run(bar="bar") + with pytest.raises(ValueError, match="bar"): + builder.run(foo="foo") + with pytest.raises(ValueError, match="bar, foo"): + builder.run() + def test_run_without_input(self): builder = ChatPromptBuilder( template=[ChatMessage.from_user("This is a template without input")]