From bad887b700ea3abbf408917c9e64e972fb7bb757 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 30 Oct 2024 12:10:34 +0100 Subject: [PATCH] feat: support for tools in `AnthropicChatGenerator` (#118) * Initial implementation * Tool use for streaming * Add system message support, proper generation_kwargs handling * system message format * Minor * Minor fixes * Improve pydocs * Check duplicate tools * Update haystack_experimental/components/generators/anthropic/__init__.py Co-authored-by: Stefano Fiorucci * Update haystack_experimental/components/generators/anthropic/chat/chat_generator.py Co-authored-by: Stefano Fiorucci * Update haystack_experimental/components/generators/anthropic/chat/chat_generator.py Co-authored-by: Stefano Fiorucci * Update haystack_experimental/components/generators/anthropic/chat/chat_generator.py Co-authored-by: Stefano Fiorucci * PR feedback * Fix class pydoc * PR feedback * Handle tools round trip * PR feedback * Small lint * Parallel tool support * Special handing for ToolCallResult messages * Simplify _convert_messages_to_anthropic_format * More tests, improve coverage * Increase coverage * Update haystack_experimental/components/generators/anthropic/chat/chat_generator.py Co-authored-by: Stefano Fiorucci * PR feedback * PR feedback - rename method * try simplification * PR feedback - pydoc updates * pylint * Simplify main message conversion algorithm * further simplifaction + test * update readme --------- Co-authored-by: Stefano Fiorucci --- README.md | 4 +- docs/pydoc/config/generators_api.yml | 1 + haystack_experimental/components/__init__.py | 2 + .../generators/anthropic/__init__.py | 7 + .../generators/anthropic/chat/__init__.py | 7 + .../anthropic/chat/chat_generator.py | 458 +++++++++ .../generators/anthropic/test_anthropic.py | 893 ++++++++++++++++++ 7 files changed, 1371 insertions(+), 1 deletion(-) create mode 100644 haystack_experimental/components/generators/anthropic/__init__.py create mode 100644 haystack_experimental/components/generators/anthropic/chat/__init__.py create mode 100644 haystack_experimental/components/generators/anthropic/chat/chat_generator.py create mode 100644 test/components/generators/anthropic/test_anthropic.py diff --git a/README.md b/README.md index 01d523a8..67e9abe8 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ The latest version of the package contains the following experiments: | [`EvaluationHarness`][1] | Evaluation orchestrator | October 2024 | None | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/74) | | [`OpenAIFunctionCaller`][2] | Function Calling Component | October 2024 | None | 🔜 | | | [`OpenAPITool`][3] | OpenAPITool component | October 2024 | jsonref | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/79)| -| Support for Tools: [refactored `ChatMessage` dataclass][10], [`Tool` dataclass][4], [refactored `OpenAIChatGenerator`][11], [refactored `OllamaChatGenerator`][14], [`ToolInvoker` component][12] | Tool Calling support | November 2024 | jsonschema | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/98)| +| Support for Tools: [refactored `ChatMessage` dataclass][10], [`Tool` dataclass][4], [refactored `OpenAIChatGenerator`][11], [refactored `OllamaChatGenerator`][14], [refactored `HuggingFaceAPIChatGenerator`][15], [refactored `AnthropicChatGenerator`][16], [`ToolInvoker` component][12] | Tool Calling support | November 2024 | jsonschema | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/98)| | [`ChatMessageWriter`][5] | Memory Component | December 2024 | None | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) | | [`ChatMessageRetriever`][6] | Memory Component | December 2024 | None | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) | | [`InMemoryChatMessageStore`][7] | Memory Store | December 2024 | None | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) | @@ -62,6 +62,8 @@ The latest version of the package contains the following experiments: [12]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/tools/tool_invoker.py [13]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/extractors/llm_metadata_extractor.py [14]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/generators/ollama/chat/chat_generator.py +[15]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/generators/chat/hugging_face_api.py +[16]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/generators/anthropic/chat/chat_generator.py ## Usage diff --git a/docs/pydoc/config/generators_api.yml b/docs/pydoc/config/generators_api.yml index ed071cd8..18b932c5 100644 --- a/docs/pydoc/config/generators_api.yml +++ b/docs/pydoc/config/generators_api.yml @@ -3,6 +3,7 @@ loaders: search_path: [../../../] modules: ["haystack_experimental.components.generators.chat.openai", "haystack_experimental.components.generators.chat.hugging_face_api", + "haystack_experimental.components.generators.anthropic.chat.chat_generator", "haystack_experimental.components.generators.ollama.chat.chat_generator"] ignore_when_discovered: ["__init__"] processors: diff --git a/haystack_experimental/components/__init__.py b/haystack_experimental/components/__init__.py index db85c6e6..7eaba976 100644 --- a/haystack_experimental/components/__init__.py +++ b/haystack_experimental/components/__init__.py @@ -4,6 +4,7 @@ from .extractors import LLMMetadataExtractor +from .generators.anthropic.chat.chat_generator import AnthropicChatGenerator from .generators.chat import HuggingFaceAPIChatGenerator, OpenAIChatGenerator from .generators.ollama.chat.chat_generator import OllamaChatGenerator from .retrievers.auto_merging_retriever import AutoMergingRetriever @@ -19,6 +20,7 @@ "HuggingFaceAPIChatGenerator", "OllamaChatGenerator", "OpenAIChatGenerator", + "AnthropicChatGenerator", "LLMMetadataExtractor", "HierarchicalDocumentSplitter", "OpenAIFunctionCaller", diff --git a/haystack_experimental/components/generators/anthropic/__init__.py b/haystack_experimental/components/generators/anthropic/__init__.py new file mode 100644 index 00000000..88c4a609 --- /dev/null +++ b/haystack_experimental/components/generators/anthropic/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .chat.chat_generator import AnthropicChatGenerator + +__all__ = ["AnthropicChatGenerator"] diff --git a/haystack_experimental/components/generators/anthropic/chat/__init__.py b/haystack_experimental/components/generators/anthropic/chat/__init__.py new file mode 100644 index 00000000..3d8dd02f --- /dev/null +++ b/haystack_experimental/components/generators/anthropic/chat/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack_experimental.components.generators.anthropic.chat.chat_generator import AnthropicChatGenerator + +__all__ = ["AnthropicChatGenerator"] diff --git a/haystack_experimental/components/generators/anthropic/chat/chat_generator.py b/haystack_experimental/components/generators/anthropic/chat/chat_generator.py new file mode 100644 index 00000000..784cfd18 --- /dev/null +++ b/haystack_experimental/components/generators/anthropic/chat/chat_generator.py @@ -0,0 +1,458 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Type + +from haystack import component, default_from_dict +from haystack.dataclasses import StreamingChunk +from haystack.lazy_imports import LazyImport +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace + +from haystack_experimental.dataclasses import ChatMessage, ToolCall +from haystack_experimental.dataclasses.chat_message import ChatRole, ToolCallResult +from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace + +logger = logging.getLogger(__name__) + + +with LazyImport("Run 'pip install anthropic-haystack'") as anthropic_integration_import: + # pylint: disable=import-error + from haystack_integrations.components.generators.anthropic import ( + AnthropicChatGenerator as AnthropicChatGeneratorBase, + ) + + from anthropic import Stream + + +# The following code block ensures that: +# - we reuse existing code where possible +# - people can use haystack-experimental without installing anthropic-haystack. +# +# If anthropic-haystack is installed: all works correctly. +# +# If anthropic-haystack is not installed: +# - haystack-experimental package works fine (no import errors). +# - AnthropicChatGenerator fails with ImportError at init (due to anthropic_integration_import.check()). + +if anthropic_integration_import.is_successful(): + chatgenerator_base_class: Type[AnthropicChatGeneratorBase] = AnthropicChatGeneratorBase +else: + chatgenerator_base_class: Type[object] = object # type: ignore[no-redef] + + +def _update_anthropic_message_with_tool_call_results( + tool_call_results: List[ToolCallResult], anthropic_msg: Dict[str, Any] +) -> None: + """ + Update an Anthropic message with tool call results. + + :param tool_call_results: The list of ToolCallResults to update the message with. + :param anthropic_msg: The Anthropic message to update. + """ + if "content" not in anthropic_msg: + anthropic_msg["content"] = [] + + for tool_call_result in tool_call_results: + if tool_call_result.origin.id is None: + raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.") + anthropic_msg["content"].append( + { + "type": "tool_result", + "tool_use_id": tool_call_result.origin.id, + "content": [{"type": "text", "text": tool_call_result.result}], + "is_error": tool_call_result.error, + } + ) + + +def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[Dict[str, Any]]: + """ + Convert a list of tool calls to the format expected by Anthropic Chat API. + + :param tool_calls: The list of ToolCalls to convert. + :return: A list of dictionaries in the format expected by Anthropic API. + """ + anthropic_tool_calls = [] + for tc in tool_calls: + if tc.id is None: + raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.") + anthropic_tool_calls.append( + { + "type": "tool_use", + "id": tc.id, + "name": tc.tool_name, + "input": tc.arguments, + } + ) + return anthropic_tool_calls + + +def _convert_messages_to_anthropic_format( + messages: List[ChatMessage], +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Convert a list of messages to the format expected by Anthropic Chat API. + + :param messages: The list of ChatMessages to convert. + :return: A tuple of two lists: + - A list of system message dictionaries in the format expected by Anthropic API. + - A list of non-system message dictionaries in the format expected by Anthropic API. + """ + + anthropic_system_messages = [] + anthropic_non_system_messages = [] + + i = 0 + while i < len(messages): + message = messages[i] + + # system messages have special format requirements for Anthropic API + # they can have only type and text fields, and they need to be passed separately + # to the Anthropic API endpoint + if message.is_from(ChatRole.SYSTEM): + anthropic_system_messages.append({"type": "text", "text": message.text}) + i += 1 + continue + + anthropic_msg: Dict[str, Any] = {"role": message._role.value, "content": []} + + if message.texts: + anthropic_msg["content"].append({"type": "text", "text": message.texts[0]}) + if message.tool_calls: + anthropic_msg["content"] += _convert_tool_calls_to_anthropic_format(message.tool_calls) + + if message.tool_call_results: + results = message.tool_call_results.copy() + # Handle consecutive tool call results + while (i + 1) < len(messages) and messages[i + 1].tool_call_results: + i += 1 + results.extend(messages[i].tool_call_results) + + _update_anthropic_message_with_tool_call_results(results, anthropic_msg) + anthropic_msg["role"] = "user" + + if not anthropic_msg["content"]: + raise ValueError( + "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`." + ) + + anthropic_non_system_messages.append(anthropic_msg) + i += 1 + + return anthropic_system_messages, anthropic_non_system_messages + + +def _check_duplicate_tool_names(tools: List[Tool]) -> None: + """ + Check for duplicate tool names. + + :param tools: The list of tools to check. + :raises ValueError: If duplicate tool names are found. + """ + tool_names = [tool.name for tool in tools] + duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} + if duplicate_tool_names: + raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + + +@component +class AnthropicChatGenerator(chatgenerator_base_class): + """ + Completes chats using Anthropic's large language models (LLMs). + + It uses [ChatMessage](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage) + format in input and output. + + You can customize how the text is generated by passing parameters to the + Anthropic API. Use the `**generation_kwargs` argument when you initialize + the component or when you run it. Any parameter that works with + `anthropic.Message.create` will work here too. + + For details on Anthropic API parameters, see + [Anthropic documentation](https://docs.anthropic.com/en/api/messages). + + Usage example: + ```python + from haystack_experimental.components.generators.anthropic import AnthropicChatGenerator + from haystack_experimental.dataclasses import ChatMessage + + generator = AnthropicChatGenerator(model="claude-3-5-sonnet-20240620", + generation_kwargs={ + "max_tokens": 1000, + "temperature": 0.7, + }) + + messages = [ChatMessage.from_system("You are a helpful, respectful and honest assistant"), + ChatMessage.from_user("What's Natural Language Processing?")] + print(generator.run(messages=messages)) + ``` + """ + + def __init__( + self, + api_key: Secret = Secret.from_env_var("ANTHROPIC_API_KEY"), + model: str = "claude-3-5-sonnet-20240620", + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ignore_tools_thinking_messages: bool = True, + tools: Optional[List[Tool]] = None, + ): + """ + Creates an instance of AnthropicChatGenerator. + + :param api_key: The Anthropic API key. + You can set it with an environment variable `ANTHROPIC_API_KEY`, or pass with this parameter + as a Secret during initialization. + :param model: The name of the Anthropic model to use. Specify one of the Anthropic models with + their Anthropic API names listed in the + [Anthropic documentation](https://docs.anthropic.com/en/docs/about-claude/models). + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. + :param generation_kwargs: Additional parameters to use for the model. These parameters are sent directly to + the Anthropic API. See Anthropic's documentation for more details on available parameters. + Supported generation_kwargs parameters are: + - `system`: The system message to be passed to the model. + - `max_tokens`: The maximum number of tokens to generate. + - `metadata`: A dictionary of metadata to be passed to the model. + - `stop_sequences`: A list of strings that the model should stop generating at. + - `temperature`: The temperature to use for sampling. + - `top_p`: The top_p value to use for nucleus sampling. + - `top_k`: The top_k value to use for top-k sampling. + :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves a + "chain of thought" messages before returning the actual function names and parameters in a message. If + `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool + use is detected. + See the Anthropic [tools](https://docs.anthropic.com/en/docs/build-with-claude/tool-use#chain-of-thought-tool-use) + for more details. + :param tools: A list of Tool objects that the model can use. Each tool should have a unique name. + """ + anthropic_integration_import.check() + + super(AnthropicChatGenerator, self).__init__( + model=model, + api_key=api_key, + generation_kwargs=generation_kwargs, + streaming_callback=streaming_callback, + ) + + if tools: + _check_duplicate_tool_names(tools) + self.tools = tools + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + The serialized component as a dictionary. + """ + serialized = super(AnthropicChatGenerator, self).to_dict() + serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None + return serialized + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AnthropicChatGenerator": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of this component. + :returns: + The deserialized component instance. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + deserialize_tools_inplace(data["init_parameters"], key="tools") + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + + return default_from_dict(cls, data) + + def _convert_chat_completion_to_chat_message(self, anthropic_response: Any) -> ChatMessage: + """ + Converts the response from the Anthropic API to a ChatMessage. + """ + text_extracted = "" + tool_calls = [] + + for content_block in anthropic_response.content: + if content_block.type == "text": + text_extracted = content_block.text + elif content_block.type == "tool_use": + tool_calls.append( + ToolCall( + tool_name=content_block.name, + arguments=content_block.input, # dict already + id=content_block.id, + ) + ) + + message = ChatMessage.from_assistant(text=text_extracted, tool_calls=tool_calls) + + # Dump the chat completion to a dict + response_dict = anthropic_response.model_dump() + + # create meta to match the openai format + message._meta.update( + { + "model": response_dict.get("model", None), + "index": 0, + "finish_reason": response_dict.get("stop_reason", None), + "usage": dict(response_dict.get("usage", {})), + } + ) + return message + + def _convert_anthropic_chunk_to_streaming_chunk(self, chunk: Any) -> StreamingChunk: + """ + Converts an Anthropic StreamEvent to a StreamingChunk. + """ + content = "" + if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": + content = chunk.delta.text + + return StreamingChunk(content=content, meta=chunk.model_dump()) + + def _convert_streaming_chunks_to_chat_message( + self, chunks: List[StreamingChunk], model: Optional[str] = None + ) -> ChatMessage: + """ + Converts a list of StreamingChunks to a ChatMessage. + """ + full_content = "" + tool_calls = [] + current_tool_call: Optional[Dict[str, Any]] = {} + + # loop through chunks and call the appropriate handler + for chunk in chunks: + chunk_type = chunk.meta.get("type") + if chunk_type == "content_block_start": + if chunk.meta.get("content_block", {}).get("type") == "tool_use": + delta_block = chunk.meta.get("content_block") + current_tool_call = { + "id": delta_block.get("id"), + "name": delta_block.get("name"), + "arguments": "", + } + elif chunk_type == "content_block_delta": + delta = chunk.meta.get("delta", {}) + if delta.get("type") == "text_delta": + full_content += delta.get("text", "") + elif delta.get("type") == "input_json_delta" and current_tool_call: + current_tool_call["arguments"] += delta.get("partial_json", "") + elif chunk_type == "message_delta": # noqa: SIM102 (prefer nested if statement here for readability) + if chunk.meta.get("delta", {}).get("stop_reason") == "tool_use" and current_tool_call: + try: + # arguments is a string, convert to json + tool_calls.append( + ToolCall( + id=current_tool_call.get("id"), + tool_name=str(current_tool_call.get("name")), + arguments=json.loads(current_tool_call.get("arguments", {})), + ) + ) + except json.JSONDecodeError: + logger.warning( + "Anthropic returned a malformed JSON string for tool call arguments. " + "This tool call will be skipped. Arguments: %s", + current_tool_call.get("arguments", ""), + ) + current_tool_call = None + + message = ChatMessage.from_assistant(full_content, tool_calls=tool_calls) + + # Update meta information + last_chunk_meta = chunks[-1].meta + message._meta.update( + { + "model": model, + "index": 0, + "finish_reason": last_chunk_meta.get("delta", {}).get("stop_reason", None), + "usage": last_chunk_meta.get("usage", {}), + } + ) + + return message + + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, + ): + """ + Invokes the Anthropic API with the given messages and generation kwargs. + + :param messages: A list of ChatMessage instances representing the input messages. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + :param generation_kwargs: Optional arguments to pass to the Anthropic generation endpoint. + :param tools: A list of tools for which the model can prepare calls. If set, it will override + the `tools` parameter set during component initialization. + :returns: A dictionary with the following keys: + - `replies`: The responses from the model + """ + # update generation kwargs by merging with the generation kwargs passed to the run method + generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + disallowed_params = set(generation_kwargs) - set(self.ALLOWED_PARAMS) + if disallowed_params: + logger.warning( + "Model parameters %s are not allowed and will be ignored. Allowed parameters are %s.", + disallowed_params, + self.ALLOWED_PARAMS, + ) + generation_kwargs = {k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS} + tools = tools or self.tools + if tools: + _check_duplicate_tool_names(tools) + + system_messages, non_system_messages = _convert_messages_to_anthropic_format(messages) + anthropic_tools = ( + [ + { + "name": tool.name, + "description": tool.description, + "input_schema": tool.parameters, + } + for tool in tools + ] + if tools + else [] + ) + + streaming_callback = streaming_callback or self.streaming_callback + + response = self.client.messages.create( + model=self.model, + messages=non_system_messages, + system=system_messages, + tools=anthropic_tools, + stream=streaming_callback is not None, + max_tokens=generation_kwargs.pop("max_tokens", 1024), + **generation_kwargs, + ) + + if isinstance(response, Stream): + chunks: List[StreamingChunk] = [] + model: Optional[str] = None + for chunk in response: + if chunk.type == "message_start": + model = chunk.message.model + elif chunk.type in [ + "content_block_start", + "content_block_delta", + "message_delta", + ]: + streaming_chunk = self._convert_anthropic_chunk_to_streaming_chunk(chunk) + chunks.append(streaming_chunk) + if streaming_callback: + streaming_callback(streaming_chunk) + + completion = self._convert_streaming_chunks_to_chat_message(chunks, model) + return {"replies": [completion]} + else: + return {"replies": [self._convert_chat_completion_to_chat_message(response)]} diff --git a/test/components/generators/anthropic/test_anthropic.py b/test/components/generators/anthropic/test_anthropic.py new file mode 100644 index 00000000..710c4b7c --- /dev/null +++ b/test/components/generators/anthropic/test_anthropic.py @@ -0,0 +1,893 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import json +import logging +import os +from unittest.mock import patch + +import pytest +from anthropic.types import Message, TextBlockParam +from haystack import Pipeline +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import StreamingChunk +from haystack.utils.auth import Secret + +from anthropic.types import ContentBlockDeltaEvent, MessageStartEvent, TextDelta, ContentBlockStartEvent + +from haystack_experimental.components.generators.anthropic.chat.chat_generator import ( + AnthropicChatGenerator, + _convert_messages_to_anthropic_format, +) +from haystack_experimental.dataclasses import ChatMessage, ChatRole, Tool, ToolCall +from haystack_experimental.dataclasses.chat_message import ToolCallResult + + +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + + return [tool] + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_user("What's the capital of France"), + ] + + +@pytest.fixture +def mock_anthropic_completion(): + with patch("anthropic.resources.messages.Messages.create") as mock_anthropic: + completion = Message( + id="foo", + type="message", + model="claude-3-5-sonnet-20240620", + role="assistant", + content=[TextBlockParam(type="text", text="Hello! I'm Claude.")], + stop_reason="end_turn", + usage={"input_tokens": 10, "output_tokens": 20}, + ) + mock_anthropic.return_value = completion + yield mock_anthropic + + +class TestAnthropicChatGenerator: + def test_init_default(self, monkeypatch): + """ + Test the default initialization of the AnthropicChatGenerator component. + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + component = AnthropicChatGenerator() + assert component.client.api_key == "test-api-key" + assert component.model == "claude-3-5-sonnet-20240620" + assert component.streaming_callback is None + assert not component.generation_kwargs + assert component.tools is None + + def test_init_fail_wo_api_key(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component fails to initialize without an API key. + """ + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + with pytest.raises(ValueError): + AnthropicChatGenerator() + + def test_init_fail_with_duplicate_tool_names(self, monkeypatch, tools): + """ + Test that the AnthropicChatGenerator component fails to initialize with duplicate tool names. + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + + duplicate_tools = [tools[0], tools[0]] + with pytest.raises(ValueError): + AnthropicChatGenerator(tools=duplicate_tools) + + def test_init_with_parameters(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component initializes with parameters. + """ + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=lambda x: x) + + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") + component = AnthropicChatGenerator( + api_key=Secret.from_token("test-api-key"), + model="claude-3-5-sonnet-20240620", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + tools=[tool], + ) + assert component.client.api_key == "test-api-key" + assert component.model == "claude-3-5-sonnet-20240620" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.tools == [tool] + + def test_init_with_parameters_and_env_vars(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component initializes with parameters and env vars. + """ + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") + component = AnthropicChatGenerator( + model="claude-3-5-sonnet-20240620", + api_key=Secret.from_token("test-api-key"), + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + assert component.client.api_key == "test-api-key" + assert component.model == "claude-3-5-sonnet-20240620" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + def test_to_dict_default(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component can be serialized to a dictionary. + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + component = AnthropicChatGenerator() + data = component.to_dict() + assert data == { + "type": "haystack_experimental.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "type": "env_var", "strict": True}, + "model": "claude-3-5-sonnet-20240620", + "streaming_callback": None, + "ignore_tools_thinking_messages": True, + "generation_kwargs": {}, + "tools": None, + }, + } + + def test_to_dict_with_parameters(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component can be serialized to a dictionary with parameters. + """ + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + monkeypatch.setenv("ENV_VAR", "test-api-key") + component = AnthropicChatGenerator( + api_key=Secret.from_env_var("ENV_VAR"), + model="claude-3-5-sonnet-20240620", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + tools=[tool], + ) + data = component.to_dict() + + assert data == { + "type": "haystack_experimental.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["ENV_VAR"], "type": "env_var", "strict": True}, + "model": "claude-3-5-sonnet-20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "ignore_tools_thinking_messages": True, + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "tools": [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": { + "x": { + "type": "string", + }, + }, + }, + ], + }, + } + + def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component can be serialized to a dictionary with a lambda streaming callback. + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + component = AnthropicChatGenerator( + model="claude-3-5-sonnet-20240620", + streaming_callback=lambda x: x, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": "haystack_experimental.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "type": "env_var", "strict": True}, + "model": "claude-3-5-sonnet-20240620", + "ignore_tools_thinking_messages": True, + "streaming_callback": "test_anthropic.", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "tools": None, + }, + } + + def test_from_dict(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component can be deserialized from a dictionary. + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-api-key") + data = { + "type": "haystack_experimental.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "type": "env_var", "strict": True}, + "model": "claude-3-5-sonnet-20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "tools": [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": { + "x": { + "type": "string", + }, + }, + }, + ], + }, + } + component = AnthropicChatGenerator.from_dict(data) + + assert isinstance(component, AnthropicChatGenerator) + assert component.model == "claude-3-5-sonnet-20240620" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.api_key == Secret.from_env_var("ANTHROPIC_API_KEY") + assert component.tools == [ + Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + ] + + def test_from_dict_fail_wo_env_var(self, monkeypatch): + """ + Test that the AnthropicChatGenerator component fails to deserialize from a dictionary without an API key. + """ + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + data = { + "type": "haystack_experimental.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "type": "env_var", "strict": True}, + "model": "claude-3-5-sonnet-20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + with pytest.raises(ValueError): + AnthropicChatGenerator.from_dict(data) + + def test_run_with_params(self, chat_messages, mock_anthropic_completion): + """ + Test that the AnthropicChatGenerator component can run with parameters. + """ + component = AnthropicChatGenerator( + api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5} + ) + response = component.run(chat_messages) + + # Check that the component calls the Anthropic API with the correct parameters + _, kwargs = mock_anthropic_completion.call_args + assert kwargs["max_tokens"] == 10 + assert kwargs["temperature"] == 0.5 + + # Check that the component returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert isinstance(response["replies"][0], ChatMessage) + assert "Hello! I'm Claude." in response["replies"][0].text + assert response["replies"][0].meta["model"] == "claude-3-5-sonnet-20240620" + assert response["replies"][0].meta["finish_reason"] == "end_turn" + + def test_check_duplicate_tool_names(self, tools): + """Test that the AnthropicChatGenerator component fails to initialize with duplicate tool names.""" + with pytest.raises(ValueError): + AnthropicChatGenerator(tools=tools + tools) + + def test_convert_anthropic_chunk_to_streaming_chunk(self): + """ + Test converting Anthropic stream events to Haystack StreamingChunks + """ + component = AnthropicChatGenerator(api_key=Secret.from_token("test-api-key")) + + # Test text delta chunk + text_delta_chunk = ContentBlockDeltaEvent( + type="content_block_delta", index=0, delta=TextDelta(type="text_delta", text="Hello, world!") + ) + streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(text_delta_chunk) + assert streaming_chunk.content == "Hello, world!" + assert streaming_chunk.meta == { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "Hello, world!"}, + } + + # Test non-text chunk (should have empty content) + message_start_chunk = MessageStartEvent( + type="message_start", + message={ + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-5-sonnet-20240620", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 25, "output_tokens": 1}, + }, + ) + streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(message_start_chunk) + assert streaming_chunk.content == "" + assert streaming_chunk.meta == { + "type": "message_start", + "message": { + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-3-5-sonnet-20240620", + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 25, "output_tokens": 1}, + }, + } + + # Test tool use chunk (should have empty content) + tool_use_chunk = ContentBlockStartEvent( + type="content_block_start", + index=1, + content_block={"type": "tool_use", "id": "toolu_123", "name": "weather", "input": {"city": "Paris"}}, + ) + streaming_chunk = component._convert_anthropic_chunk_to_streaming_chunk(tool_use_chunk) + assert streaming_chunk.content == "" + assert streaming_chunk.meta == { + "type": "content_block_start", + "index": 1, + "content_block": {"type": "tool_use", "id": "toolu_123", "name": "weather", "input": {"city": "Paris"}}, + } + + def test_convert_streaming_chunks_to_chat_message(self): + """ + Test converting streaming chunks to a chat message with tool calls + """ + # Create a sequence of streaming chunks that simulate Anthropic's response + chunks = [ + # Initial text content + StreamingChunk( + content="", + meta={"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}, + ), + StreamingChunk( + content="Let me check", + meta={ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "Let me check"}, + }, + ), + StreamingChunk( + content=" the weather", + meta={ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": " the weather"}, + }, + ), + StreamingChunk(content="", meta={"type": "content_block_stop", "index": 0}), + # Tool use content + StreamingChunk( + content="", + meta={ + "type": "content_block_start", + "index": 1, + "content_block": {"type": "tool_use", "id": "toolu_123", "name": "weather", "input": {}}, + }, + ), + StreamingChunk( + content="", + meta={ + "type": "content_block_delta", + "index": 1, + "delta": {"type": "input_json_delta", "partial_json": '{"city":'}, + }, + ), + StreamingChunk( + content="", + meta={ + "type": "content_block_delta", + "index": 1, + "delta": {"type": "input_json_delta", "partial_json": ' "Paris"}'}, + }, + ), + StreamingChunk(content="", meta={"type": "content_block_stop", "index": 1}), + # Final message delta + StreamingChunk( + content="", + meta={ + "type": "message_delta", + "delta": {"stop_reason": "tool_use", "stop_sequence": None}, + "usage": {"output_tokens": 40}, + }, + ), + ] + + component = AnthropicChatGenerator(api_key=Secret.from_token("test-api-key")) + message = component._convert_streaming_chunks_to_chat_message(chunks, model="claude-3-sonnet") + + # Verify the message content + assert message.text == "Let me check the weather" + + # Verify tool calls + assert len(message.tool_calls) == 1 + tool_call = message.tool_calls[0] + assert tool_call.id == "toolu_123" + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + + # Verify meta information + assert message._meta["model"] == "claude-3-sonnet" + assert message._meta["index"] == 0 + assert message._meta["finish_reason"] == "tool_use" + assert message._meta["usage"] == {"output_tokens": 40} + + def test_convert_streaming_chunks_to_chat_message_malformed_json(self, caplog): + """ + Test converting streaming chunks with malformed JSON in tool arguments (increases coverage) + """ + chunks = [ + # Initial text content + StreamingChunk( + content="", + meta={"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}, + ), + StreamingChunk( + content="Let me check the weather", + meta={ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "Let me check the weather"}, + }, + ), + StreamingChunk(content="", meta={"type": "content_block_stop", "index": 0}), + # Tool use content with malformed JSON + StreamingChunk( + content="", + meta={ + "type": "content_block_start", + "index": 1, + "content_block": {"type": "tool_use", "id": "toolu_123", "name": "weather", "input": {}}, + }, + ), + StreamingChunk( + content="", + meta={ + "type": "content_block_delta", + "index": 1, + "delta": {"type": "input_json_delta", "partial_json": '{"city":'}, + }, + ), + StreamingChunk( + content="", + meta={ + "type": "content_block_delta", + "index": 1, + "delta": { + "type": "input_json_delta", + "partial_json": ' "Paris', # Missing closing quote and brace, malformed JSON + }, + }, + ), + StreamingChunk(content="", meta={"type": "content_block_stop", "index": 1}), + # Final message delta + StreamingChunk( + content="", + meta={ + "type": "message_delta", + "delta": {"stop_reason": "tool_use", "stop_sequence": None}, + "usage": {"output_tokens": 40}, + }, + ), + ] + + component = AnthropicChatGenerator(api_key=Secret.from_token("test-api-key")) + message = component._convert_streaming_chunks_to_chat_message(chunks, model="claude-3-sonnet") + + # Verify the message content is preserve + assert message.text == "Let me check the weather" + + # But the tool_calls are empty + assert len(message.tool_calls) == 0 + + # and we have logged a warning + with caplog.at_level(logging.WARNING): + assert "Anthropic returned a malformed JSON string" in caplog.text + + def test_serde_in_pipeline(self): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = AnthropicChatGenerator( + api_key=Secret.from_env_var("ANTHROPIC_API_KEY", strict=False), + model="claude-3-5-sonnet-20240620", + generation_kwargs={"temperature": 0.6}, + tools=[tool], + ) + + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + pipeline_dict = pipeline.to_dict() + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": "haystack_experimental.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", + "init_parameters": { + "api_key": {"type": "env_var", "env_vars": ["ANTHROPIC_API_KEY"], "strict": False}, + "model": "claude-3-5-sonnet-20240620", + "generation_kwargs": {"temperature": 0.6}, + "ignore_tools_thinking_messages": True, + "streaming_callback": None, + "tools": [ + { + "name": "name", + "description": "description", + "parameters": {"x": {"type": "string"}}, + "function": "builtins.print", + } + ], + }, + } + }, + "connections": [], + } + + pipeline_yaml = pipeline.dumps() + + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY", None), + reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.", + ) + @pytest.mark.integration + def test_live_run(self): + """ + Integration test that the AnthropicChatGenerator component can run with default parameters. + """ + component = AnthropicChatGenerator() + results = component.run(messages=[ChatMessage.from_user("What's the capital of France?")]) + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.text + assert "claude-3-5-sonnet-20240620" in message.meta["model"] + assert message.meta["finish_reason"] == "end_turn" + + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY", None), + reason="Export an env var called ANTHROPIC_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_streaming(self): + """ + Integration test that the AnthropicChatGenerator component can run with streaming. + """ + + class Callback: + def __init__(self): + self.responses = "" + self.counter = 0 + + def __call__(self, chunk: StreamingChunk) -> None: + self.counter += 1 + self.responses += chunk.content if chunk.content else "" + + callback = Callback() + component = AnthropicChatGenerator(streaming_callback=callback) + results = component.run([ChatMessage.from_user("What's the capital of France?")]) + + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.text + + assert "claude-3-5-sonnet-20240620" in message.meta["model"] + assert message.meta["finish_reason"] == "end_turn" + + assert callback.counter > 1 + assert "Paris" in callback.responses + + def test_convert_message_to_anthropic_format(self): + """ + Test that the AnthropicChatGenerator component can convert a ChatMessage to Anthropic format. + """ + messages = [ChatMessage.from_system("You are good assistant")] + assert _convert_messages_to_anthropic_format(messages) == ( + [{"type": "text", "text": "You are good assistant"}], + [], + ) + + messages = [ChatMessage.from_user("I have a question")] + assert _convert_messages_to_anthropic_format(messages) == ( + [], + [{"role": "user", "content": [{"type": "text", "text": "I have a question"}]}], + ) + + messages = [ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"})] + assert _convert_messages_to_anthropic_format(messages) == ( + [], + [{"role": "assistant", "content": [{"type": "text", "text": "I have an answer"}]}], + ) + + messages = [ + ChatMessage.from_assistant( + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})] + ) + ] + result = _convert_messages_to_anthropic_format(messages) + assert result == ( + [], + [ + { + "role": "assistant", + "content": [{"type": "tool_use", "id": "123", "name": "weather", "input": {"city": "Paris"}}], + } + ], + ) + + tool_result = json.dumps({"weather": "sunny", "temperature": "25"}) + messages = [ + ChatMessage.from_tool( + tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + ) + ] + assert _convert_messages_to_anthropic_format(messages) == ( + [], + [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "123", + "content": [{"type": "text", "text": '{"weather": "sunny", "temperature": "25"}'}], + "is_error": False, + } + ], + } + ], + ) + + messages = [ + ChatMessage.from_assistant( + text="For that I'll need to check the weather", + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})], + ) + ] + result = _convert_messages_to_anthropic_format(messages) + assert result == ( + [], + [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "For that I'll need to check the weather"}, + {"type": "tool_use", "id": "123", "name": "weather", "input": {"city": "Paris"}}, + ], + } + ], + ) + + def test_convert_message_to_anthropic_format_complex(self): + """ + Test that the AnthropicChatGenerator component can convert a complex sequence of ChatMessages to Anthropic format. + In particular, we check that different tool results are packed in a single dictionary with role=user. + """ + + messages = [ + ChatMessage.from_system("You are good assistant"), + ChatMessage.from_user("What's the weather like in Paris? And how much is 2+2?"), + ChatMessage.from_assistant( + text="", + tool_calls=[ + ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}), + ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"}), + ], + ), + ChatMessage.from_tool( + tool_result="22° C", origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + ), + ChatMessage.from_tool( + tool_result="4", origin=ToolCall(id="456", tool_name="math", arguments={"expression": "2+2"}) + ), + ] + + system_messages, non_system_messages = _convert_messages_to_anthropic_format(messages) + + assert system_messages == [{"type": "text", "text": "You are good assistant"}] + assert non_system_messages == [ + { + "role": "user", + "content": [{"type": "text", "text": "What's the weather like in Paris? And how much is 2+2?"}], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": ""}, + {"type": "tool_use", "id": "123", "name": "weather", "input": {"city": "Paris"}}, + {"type": "tool_use", "id": "456", "name": "math", "input": {"expression": "2+2"}}, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "123", + "content": [{"type": "text", "text": "22° C"}], + "is_error": False, + }, + { + "type": "tool_result", + "tool_use_id": "456", + "content": [{"type": "text", "text": "4"}], + "is_error": False, + }, + ], + }, + ] + + def test_convert_message_to_anthropic_invalid(self): + """ + Test that the AnthropicChatGenerator component fails to convert an invalid ChatMessage to Anthropic format. + """ + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[]) + with pytest.raises(ValueError): + _convert_messages_to_anthropic_format([message]) + + tool_call_null_id = ToolCall(id=None, tool_name="weather", arguments={"city": "Paris"}) + message = ChatMessage.from_assistant(tool_calls=[tool_call_null_id]) + with pytest.raises(ValueError): + _convert_messages_to_anthropic_format([message]) + + message = ChatMessage.from_tool(tool_result="result", origin=tool_call_null_id) + with pytest.raises(ValueError): + _convert_messages_to_anthropic_format([message]) + + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY", None), + reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_tools(self, tools): + """ + Integration test that the AnthropicChatGenerator component can run with tools. + """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AnthropicChatGenerator(tools=tools) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.id is not None + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_use" + + new_messages = initial_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)] + # the model tends to make tool calls if provided with tools, so we don't pass them here + results = component.run(new_messages, generation_kwargs={"max_tokens": 50}) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY", None), + reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_tools_streaming(self, tools): + """ + Integration test that the AnthropicChatGenerator component can run with tools and streaming. + """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AnthropicChatGenerator(tools=tools, streaming_callback=print_streaming_chunk) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + + # this is Antropic thinking message prior to tool call + assert message.text is not None + assert "weather" in message.text.lower() + assert "paris" in message.text.lower() + + # now we have the tool call + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.id is not None + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_use" + + new_messages = initial_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)] + results = component.run(new_messages) + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY", None), + reason="Export an env var called ANTHROPIC_API_KEY containing the Anthropic API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_parallel_tools(self, tools): + """ + Integration test that the AnthropicChatGenerator component can run with parallel tools. + """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")] + component = AnthropicChatGenerator(tools=tools) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + + # now we have the tool call + assert len(message.tool_calls) == 2 + tool_call_paris = message.tool_calls[0] + assert isinstance(tool_call_paris, ToolCall) + assert tool_call_paris.id is not None + assert tool_call_paris.tool_name == "weather" + assert tool_call_paris.arguments == {"city": "Paris"} or tool_call_paris.arguments == {"city": "Berlin"} + assert message.meta["finish_reason"] == "tool_use" + + tool_call_berlin = message.tool_calls[1] + assert isinstance(tool_call_berlin, ToolCall) + assert tool_call_berlin.id is not None + assert tool_call_berlin.tool_name == "weather" + assert tool_call_berlin.arguments == {"city": "Berlin"} or tool_call_berlin.arguments == {"city": "Paris"} + + # Anthropic expects results from both tools in the same message + # see https://docs.anthropic.com/en/docs/build-with-claude/tool-use#handling-tool-use-and-tool-result-content-blocks + # the docs state: + # [optional] Continue the conversation by sending a new message with the role of user, and a content block containing + # the tool_result type and the following information: + # tool_use_id: The id of the tool use request this is a result for. + # content: The result of the tool, as a string (e.g. "content": "15 degrees") or list of + # nested content blocks (e.g. "content": [{"type": "text", "text": "15 degrees"}]). + # These content blocks can use the text or image types. + # is_error (optional): Set to true if the tool execution resulted in an error. + new_messages = initial_messages + [ + message, + ChatMessage.from_tool(tool_result="22° C", origin=tool_call_paris, error=False), + ChatMessage.from_tool(tool_result="12° C", origin=tool_call_berlin, error=False), + ] + + # Response from the model contains results from both tools + results = component.run(new_messages) + message = results["replies"][0] + assert not message.tool_calls + assert len(message.text) > 0 + assert "paris" in message.text.lower() + assert "berlin" in message.text.lower() + assert "22°" in message.text + assert "12°" in message.text + assert message.meta["finish_reason"] == "end_turn"