diff --git a/haystack_experimental/components/generators/ollama/chat/chat_generator.py b/haystack_experimental/components/generators/ollama/chat/chat_generator.py
index dae76ec9..06441253 100644
--- a/haystack_experimental/components/generators/ollama/chat/chat_generator.py
+++ b/haystack_experimental/components/generators/ollama/chat/chat_generator.py
@@ -16,6 +16,8 @@
 # pylint: disable=import-error
 from haystack_integrations.components.generators.ollama import OllamaChatGenerator as OllamaChatGeneratorBase
 
+from ollama import ChatResponse
+
 # The following code block ensures that:
 # - we reuse existing code where possible
@@ -175,11 +177,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":
         return default_from_dict(cls, data)
 
-    def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
+    def _build_message_from_ollama_response(self, ollama_response: "ChatResponse") -> ChatMessage:
         """
         Converts the non-streaming response from the Ollama API to a ChatMessage.
         """
-        ollama_message = ollama_response["message"]
+        response_dict = ollama_response.model_dump()
+
+        ollama_message = response_dict["message"]
 
         text = ollama_message["content"]
@@ -192,7 +196,7 @@ def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
         message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)
 
-        message.meta.update({key: value for key, value in ollama_response.items() if key != "message"})
+        message.meta.update({key: value for key, value in response_dict.items() if key != "message"})
         return message
 
     def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
diff --git a/pyproject.toml b/pyproject.toml
index 45324448..c21b26f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,7 +64,7 @@ extra-dependencies = [
   "fastapi",
   # Tools support
   "jsonschema",
-  "ollama-haystack>=1.1.0",
+  "ollama-haystack>=2.0",
   # Async
   "opensearch-haystack",
   "opensearch-py[async]",
diff --git a/test/components/generators/ollama/test_chat_generator.py b/test/components/generators/ollama/test_chat_generator.py
index 42bc3796..95f82bb3 100644
--- a/test/components/generators/ollama/test_chat_generator.py
+++ b/test/components/generators/ollama/test_chat_generator.py
@@ -5,7 +5,7 @@
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import StreamingChunk
-from ollama._types import ResponseError
+from ollama._types import ResponseError, ChatResponse
 
 from haystack_experimental.dataclasses import (
     ChatMessage,
@@ -225,18 +225,18 @@ def test_from_dict(self):
     def test_build_message_from_ollama_response(self):
         model = "some_model"
 
-        ollama_response = {
-            "model": model,
-            "created_at": "2023-12-12T14:13:43.416799Z",
-            "message": {"role": "assistant", "content": "Hello! How are you today?"},
-            "done": True,
-            "total_duration": 5191566416,
-            "load_duration": 2154458,
-            "prompt_eval_count": 26,
-            "prompt_eval_duration": 383809000,
-            "eval_count": 298,
-            "eval_duration": 4799921000,
-        }
+        ollama_response = ChatResponse(
+            model=model,
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={"role": "assistant", "content": "Hello! How are you today?"},
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
 
         observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)
@@ -246,10 +246,10 @@ def test_build_message_from_ollama_response_with_tools(self):
         model = "some_model"
 
-        ollama_response = {
-            "model": model,
-            "created_at": "2023-12-12T14:13:43.416799Z",
-            "message": {
+        ollama_response = ChatResponse(
+            model=model,
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={
                 "role": "assistant",
                 "content": "",
                 "tool_calls": [
                     {
                         "function": {
@@ -261,14 +261,14 @@
                     }
                 ],
             },
-            "done": True,
-            "total_duration": 5191566416,
-            "load_duration": 2154458,
-            "prompt_eval_count": 26,
-            "prompt_eval_duration": 383809000,
-            "eval_count": 298,
-            "eval_duration": 4799921000,
-        }
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
 
         observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)
@@ -283,21 +283,21 @@ def test_run(self, mock_client):
         generator = OllamaChatGenerator()
 
-        mock_response = {
-            "model": "llama3.2",
-            "created_at": "2023-12-12T14:13:43.416799Z",
-            "message": {
+        mock_response = ChatResponse(
+            model="llama3.2",
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={
                 "role": "assistant",
                 "content": "Fine. How can I help you today?",
             },
-            "done": True,
-            "total_duration": 5191566416,
-            "load_duration": 2154458,
-            "prompt_eval_count": 26,
-            "prompt_eval_duration": 383809000,
-            "eval_count": 298,
-            "eval_duration": 4799921000,
-        }
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
 
         mock_client_instance = mock_client.return_value
         mock_client_instance.chat.return_value = mock_response
@@ -330,24 +330,24 @@ def streaming_callback(chunk: StreamingChunk) -> None:
         mock_response = iter(
             [
-                {
-                    "model": "llama3.2",
-                    "created_at": "2023-12-12T14:13:43.416799Z",
-                    "message": {"role": "assistant", "content": "first chunk "},
-                    "done": False,
-                },
-                {
-                    "model": "llama3.2",
-                    "created_at": "2023-12-12T14:13:43.416799Z",
-                    "message": {"role": "assistant", "content": "second chunk"},
-                    "done": True,
-                    "total_duration": 4883583458,
-                    "load_duration": 1334875,
-                    "prompt_eval_count": 26,
-                    "prompt_eval_duration": 342546000,
-                    "eval_count": 282,
-                    "eval_duration": 4535599000,
-                },
+                ChatResponse(
+                    model="llama3.2",
+                    created_at="2023-12-12T14:13:43.416799Z",
+                    message={"role": "assistant", "content": "first chunk "},
+                    done=False,
+                ),
+                ChatResponse(
+                    model="llama3.2",
+                    created_at="2023-12-12T14:13:43.416799Z",
+                    message={"role": "assistant", "content": "second chunk"},
+                    done=True,
+                    total_duration=4883583458,
+                    load_duration=1334875,
+                    prompt_eval_count=26,
+                    prompt_eval_duration=342546000,
+                    eval_count=282,
+                    eval_duration=4535599000,
+                ),
             ]
         )