adapt to Ollama client 0.4.0 (#139)
anakin87 authored Nov 22, 2024
1 parent aadc16d commit 4c1dd83
Showing 3 changed files with 63 additions and 59 deletions.
@@ -16,6 +16,8 @@
 # pylint: disable=import-error
 from haystack_integrations.components.generators.ollama import OllamaChatGenerator as OllamaChatGeneratorBase
 
+from ollama import ChatResponse
+
 
 # The following code block ensures that:
 # - we reuse existing code where possible
@@ -175,11 +177,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":

         return default_from_dict(cls, data)
 
-    def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
+    def _build_message_from_ollama_response(self, ollama_response: "ChatResponse") -> ChatMessage:
         """
         Converts the non-streaming response from the Ollama API to a ChatMessage.
         """
-        ollama_message = ollama_response["message"]
+        response_dict = ollama_response.model_dump()
+
+        ollama_message = response_dict["message"]
 
         text = ollama_message["content"]
 
@@ -192,7 +196,7 @@ def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -

         message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)
 
-        message.meta.update({key: value for key, value in ollama_response.items() if key != "message"})
+        message.meta.update({key: value for key, value in response_dict.items() if key != "message"})
         return message
 
     def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
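
For context on the change above: starting with ollama-python 0.4.0, Client.chat() returns a typed ChatResponse Pydantic model rather than a plain dict, which is why the generator now calls model_dump() before reading fields. A minimal sketch of the pattern, assuming a local Ollama server and an illustrative model name:

    from ollama import Client, ChatResponse

    client = Client()  # assumes an Ollama server on localhost:11434

    # In ollama>=0.4.0 this returns a ChatResponse model, not a dict
    response: ChatResponse = client.chat(
        model="llama3.2",  # illustrative model name
        messages=[{"role": "user", "content": "Hello!"}],
    )

    # Typed attribute access works directly on the model...
    print(response.message.content)

    # ...while model_dump() restores the dict shape that the existing
    # dict-based code expects, as _build_message_from_ollama_response does
    response_dict = response.model_dump()
    assert response_dict["message"]["content"] == response.message.content
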
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -64,7 +64,7 @@ extra-dependencies = [
"fastapi",
# Tools support
"jsonschema",
"ollama-haystack>=1.1.0",
"ollama-haystack>=2.0",
# Async
"opensearch-haystack",
"opensearch-py[async]",
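
The dependency bump follows the same change: the commit title suggests the ollama-haystack 2.x series is the one built against the 0.4 client. A quick standard-library check of the installed combination, a sketch only:

    from importlib.metadata import version

    # Package names as published on PyPI
    for pkg in ("ollama", "ollama-haystack"):
        print(pkg, version(pkg))
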
test/components/generators/ollama/test_chat_generator.py (110 changes: 55 additions & 55 deletions)
@@ -5,7 +5,7 @@

 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import StreamingChunk
-from ollama._types import ResponseError
+from ollama._types import ResponseError, ChatResponse
 
 from haystack_experimental.dataclasses import (
     ChatMessage,
@@ -225,18 +225,18 @@ def test_from_dict(self):
     def test_build_message_from_ollama_response(self):
         model = "some_model"
 
-        ollama_response = {
-            "model": model,
-            "created_at": "2023-12-12T14:13:43.416799Z",
-            "message": {"role": "assistant", "content": "Hello! How are you today?"},
-            "done": True,
-            "total_duration": 5191566416,
-            "load_duration": 2154458,
-            "prompt_eval_count": 26,
-            "prompt_eval_duration": 383809000,
-            "eval_count": 298,
-            "eval_duration": 4799921000,
-        }
+        ollama_response = ChatResponse(
+            model=model,
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={"role": "assistant", "content": "Hello! How are you today?"},
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
 
         observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)
 
@@ -246,10 +246,10 @@ def test_build_message_from_ollama_response(self):
     def test_build_message_from_ollama_response_with_tools(self):
         model = "some_model"
 
-        ollama_response = {
-            "model": model,
-            "created_at": "2023-12-12T14:13:43.416799Z",
-            "message": {
+        ollama_response = ChatResponse(
+            model=model,
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={
                 "role": "assistant",
                 "content": "",
                 "tool_calls": [
@@ -261,14 +261,14 @@ def test_build_message_from_ollama_response_with_tools(self):
                     }
                 ],
             },
-            "done": True,
-            "total_duration": 5191566416,
-            "load_duration": 2154458,
-            "prompt_eval_count": 26,
-            "prompt_eval_duration": 383809000,
-            "eval_count": 298,
-            "eval_duration": 4799921000,
-        }
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
 
         observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)
 
@@ -283,21 +283,21 @@ def test_build_message_from_ollama_response_with_tools(self):
     def test_run(self, mock_client):
         generator = OllamaChatGenerator()
 
-        mock_response = {
-            "model": "llama3.2",
-            "created_at": "2023-12-12T14:13:43.416799Z",
-            "message": {
+        mock_response = ChatResponse(
+            model="llama3.2",
+            created_at="2023-12-12T14:13:43.416799Z",
+            message={
                 "role": "assistant",
                 "content": "Fine. How can I help you today?",
             },
-            "done": True,
-            "total_duration": 5191566416,
-            "load_duration": 2154458,
-            "prompt_eval_count": 26,
-            "prompt_eval_duration": 383809000,
-            "eval_count": 298,
-            "eval_duration": 4799921000,
-        }
+            done=True,
+            total_duration=5191566416,
+            load_duration=2154458,
+            prompt_eval_count=26,
+            prompt_eval_duration=383809000,
+            eval_count=298,
+            eval_duration=4799921000,
+        )
 
         mock_client_instance = mock_client.return_value
         mock_client_instance.chat.return_value = mock_response
@@ -330,24 +330,24 @@ def streaming_callback(chunk: StreamingChunk) -> None:

         mock_response = iter(
             [
-                {
-                    "model": "llama3.2",
-                    "created_at": "2023-12-12T14:13:43.416799Z",
-                    "message": {"role": "assistant", "content": "first chunk "},
-                    "done": False,
-                },
-                {
-                    "model": "llama3.2",
-                    "created_at": "2023-12-12T14:13:43.416799Z",
-                    "message": {"role": "assistant", "content": "second chunk"},
-                    "done": True,
-                    "total_duration": 4883583458,
-                    "load_duration": 1334875,
-                    "prompt_eval_count": 26,
-                    "prompt_eval_duration": 342546000,
-                    "eval_count": 282,
-                    "eval_duration": 4535599000,
-                },
+                ChatResponse(
+                    model="llama3.2",
+                    created_at="2023-12-12T14:13:43.416799Z",
+                    message={"role": "assistant", "content": "first chunk "},
+                    done=False,
+                ),
+                ChatResponse(
+                    model="llama3.2",
+                    created_at="2023-12-12T14:13:43.416799Z",
+                    message={"role": "assistant", "content": "second chunk"},
+                    done=True,
+                    total_duration=4883583458,
+                    load_duration=1334875,
+                    prompt_eval_count=26,
+                    prompt_eval_duration=342546000,
+                    eval_count=282,
+                    eval_duration=4535599000,
+                ),
             ]
         )
 
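
The streaming test mirrors the real client: with stream=True, ollama>=0.4.0 yields one ChatResponse per chunk instead of one dict per chunk. A sketch of consuming such a stream, using the same chunk shape the test mocks (model name and prompt are illustrative):

    from ollama import Client

    client = Client()  # assumes a local Ollama server

    stream = client.chat(
        model="llama3.2",  # illustrative model name
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    )

    for chunk in stream:
        # Each chunk carries a partial assistant message; the final chunk
        # has done=True and also carries the timing/eval metadata
        print(chunk.message.content, end="", flush=True)
        if chunk.done:
            break
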
