Ollama tool streaming update phi 2122 #1494

Merged (2 commits) on Dec 5, 2024
cookbook/vectordb/lance_db.py (2 additions & 2 deletions)
@@ -8,7 +8,7 @@
 # By default, it stores data in /tmp/lancedb
 vector_db = LanceDb(
     table_name="recipes",
-    uri="/tmp/lancedb"  # You can change this path to store data elsewhere
+    uri="/tmp/lancedb",  # You can change this path to store data elsewhere
 )

 # Create knowledge base
@@ -21,4 +21,4 @@

 # Create and use the agent
 agent = Agent(knowledge_base=knowledge_base, use_tools=True, show_tool_calls=True)
-agent.print_response("How to make Tom Kha Gai", markdown=True)
\ No newline at end of file
+agent.print_response("How to make Tom Kha Gai", markdown=True)
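Note for anyone running this cookbook locally: the comma fix is stylistic (the call parses either way), but it keeps the argument list consistent if more parameters are added later. For context, a minimal self-contained sketch of the pattern this file follows; the PDF knowledge source and its URL are assumptions, since that part of the file is elided from the diff:

from phi.agent import Agent
from phi.knowledge.pdf import PDFUrlKnowledgeBase
from phi.vectordb.lancedb import LanceDb

# Vector store backing the knowledge base; data lives under /tmp/lancedb
vector_db = LanceDb(
    table_name="recipes",
    uri="/tmp/lancedb",  # Trailing comma matches the fix above
)

# Hypothetical knowledge source; the real one is not shown in this diff
knowledge_base = PDFUrlKnowledgeBase(urls=["https://example.com/recipes.pdf"], vector_db=vector_db)
knowledge_base.load(recreate=False)

agent = Agent(knowledge_base=knowledge_base, use_tools=True, show_tool_calls=True)
agent.print_response("How to make Tom Kha Gai", markdown=True)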
phi/model/ollama/chat.py (35 additions & 73 deletions)
@@ -1,7 +1,7 @@
 import json

 from dataclasses import dataclass, field
-from typing import Optional, List, Iterator, Dict, Any, Mapping, Union, Tuple
+from typing import Optional, List, Iterator, Dict, Any, Mapping, Union

 from phi.model.base import Model
 from phi.model.message import Message
@@ -531,31 +531,6 @@ def _handle_stream_tool_calls(

         self._format_function_call_results(function_call_results, messages)

-    def _handle_tool_call_chunk(self, content, tool_call_buffer, message_data) -> Tuple[str, bool]:
-        """
-        Handle a tool call chunk for response stream.
-
-        Args:
-            content: The content of the tool call.
-            tool_call_buffer: The tool call buffer.
-            message_data: The message data.
-
-        Returns:
-            Tuple[str, bool]: The tool call buffer and a boolean indicating if the tool call is complete.
-        """
-        tool_call_buffer += content
-        brace_count = tool_call_buffer.count("{") - tool_call_buffer.count("}")
-
-        if brace_count == 0:
-            try:
-                tool_call_data = json.loads(tool_call_buffer)
-                message_data.tool_call_blocks.append(tool_call_data)
-            except json.JSONDecodeError:
-                logger.error("Failed to parse tool call JSON.")
-            return "", False
-
-        return tool_call_buffer, True
-
     def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]:
         """
         Generate a streaming response from Ollama.
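For context on what this removal undoes: the old streaming path treated any chunk starting with "{" as the beginning of a tool call, buffered chunks until brace counts balanced, then parsed the buffer as JSON. A self-contained sketch of that heuristic, with illustrative names rather than the codebase's own:

import json
from typing import List, Tuple

def buffer_tool_call(chunk: str, buffer: str, parsed: List[dict]) -> Tuple[str, bool]:
    """Accumulate streamed text until "{" and "}" counts balance, then parse."""
    buffer += chunk
    if buffer.count("{") - buffer.count("}") == 0:
        try:
            parsed.append(json.loads(buffer))
        except json.JSONDecodeError:
            pass  # The removed code only logged the parse failure
        return "", False  # Buffer cleared; no longer inside a tool call
    return buffer, True  # Braces unbalanced; keep buffering

calls: List[dict] = []
buf, in_call = "", False
for piece in ['{"name": "get_weather", ', '"arguments": {"city": "Paris"}}']:
    buf, in_call = buffer_tool_call(piece, buf, calls)
print(calls)  # [{'name': 'get_weather', 'arguments': {'city': 'Paris'}}]

The fragility is visible: a literal brace in ordinary assistant text flips the parser into tool-call mode, which is presumably why this PR replaces the heuristic with the structured tool_calls field used below.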
@@ -569,52 +544,39 @@ def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]:
         logger.debug("---------- Ollama Response Start ----------")
         self._log_messages(messages)
         message_data = MessageData()
-        ignored_content = frozenset(["json", "\n", ";", ";\n"])
         metrics: Metrics = Metrics()

         # -*- Generate response
         metrics.response_timer.start()
         for response in self.invoke_stream(messages=messages):
             # logger.debug(f"Response: {response.get('message', {}).get('content', '')}")
             message_data.response_message = response.get("message", {})
             if message_data.response_message:
                 metrics.output_tokens += 1
                 if metrics.output_tokens == 1:
                     metrics.time_to_first_token = metrics.response_timer.elapsed

-                message_data.response_content_chunk = message_data.response_message.get("content", "").strip("`")
-
-                if message_data.response_content_chunk:
-                    if message_data.in_tool_call:
-                        message_data.tool_call_chunk, message_data.in_tool_call = self._handle_tool_call_chunk(
-                            message_data.response_content_chunk, message_data.tool_call_chunk, message_data
-                        )
-                    elif message_data.response_content_chunk.strip().startswith("{"):
-                        message_data.in_tool_call = True
-                        message_data.tool_call_chunk, message_data.in_tool_call = self._handle_tool_call_chunk(
-                            message_data.response_content_chunk, message_data.tool_call_chunk, message_data
-                        )
-                    else:
-                        if message_data.response_content_chunk not in ignored_content:
-                            yield ModelResponse(content=message_data.response_content_chunk)
-                            message_data.response_content += message_data.response_content_chunk
+                message_data.response_content_chunk = message_data.response_message.get("content", "")
+
+                yield ModelResponse(content=message_data.response_content_chunk)
+                message_data.response_content += message_data.response_content_chunk
+
+                message_data.tool_call_blocks = message_data.response_message.get("tool_calls")  # type: ignore
+                if message_data.tool_call_blocks is not None:
+                    for block in message_data.tool_call_blocks:
+                        tool_call = block.get("function")
+                        tool_name = tool_call.get("name")
+                        tool_args = tool_call.get("arguments")
+
+                        function_def = {
+                            "name": tool_name,
+                            "arguments": json.dumps(tool_args) if tool_args is not None else None,
+                        }
+                        message_data.tool_calls.append({"type": "function", "function": function_def})

             if response.get("done"):
                 message_data.response_usage = response
                 metrics.response_timer.stop()

-        # Format tool calls
-        if message_data.tool_call_blocks is not None:
-            for block in message_data.tool_call_blocks:
-                tool_name = block.get("name")
-                tool_args = block.get("parameters")
-
-                function_def = {
-                    "name": tool_name,
-                    "arguments": json.dumps(tool_args) if tool_args is not None else None,
-                }
-                message_data.tool_calls.append({"type": "function", "function": function_def})
-
         # -*- Create assistant message
         assistant_message = Message(role="assistant", content=message_data.response_content)
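The rewritten loop no longer inspects content text for tool calls at all; it reads them from the structured tool_calls field that Ollama attaches to streamed messages. A small sketch of the transformation each iteration applies; the sample chunk shape follows Ollama's chat API, with illustrative values:

import json

# A representative streamed chunk (values are illustrative)
response = {
    "message": {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {"function": {"name": "get_weather", "arguments": {"city": "Paris"}}},
        ],
    },
    "done": False,
}

tool_calls = []
for block in response["message"].get("tool_calls") or []:
    tool_call = block["function"]
    function_def = {
        "name": tool_call.get("name"),
        # Ollama sends arguments as a dict; phi stores them as a JSON string
        "arguments": json.dumps(tool_call.get("arguments")),
    }
    tool_calls.append({"type": "function", "function": function_def})

print(tool_calls)
# [{'type': 'function', 'function': {'name': 'get_weather', 'arguments': '{"city": "Paris"}'}}]

Note that the deleted post-stream block read block.get("name") and block.get("parameters") because its blocks came from model-emitted JSON text; the new blocks come from the API itself, which nests both fields under "function" as "name" and "arguments".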

@@ -652,7 +614,6 @@ async def aresponse_stream(self, messages: List[Message]) -> Any:
         logger.debug("---------- Ollama Async Response Start ----------")
         self._log_messages(messages)
         message_data = MessageData()
-        ignored_content = frozenset(["json", "\n", ";", ";\n"])
         metrics: Metrics = Metrics()

         # -*- Generate response
@@ -664,22 +625,23 @@ async def aresponse_stream(self, messages: List[Message]) -> Any:
                 if metrics.output_tokens == 1:
                     metrics.time_to_first_token = metrics.response_timer.elapsed

-                message_data.response_content_chunk = message_data.response_message.get("content", "").strip("`")
-
-                if message_data.response_content_chunk:
-                    if message_data.in_tool_call:
-                        message_data.tool_call_chunk, message_data.in_tool_call = self._handle_tool_call_chunk(
-                            message_data.response_content_chunk, message_data.tool_call_chunk, message_data
-                        )
-                    elif message_data.response_content_chunk.strip().startswith("{"):
-                        message_data.in_tool_call = True
-                        message_data.tool_call_chunk, message_data.in_tool_call = self._handle_tool_call_chunk(
-                            message_data.response_content_chunk, message_data.tool_call_chunk, message_data
-                        )
-                    else:
-                        if message_data.response_content_chunk not in ignored_content:
-                            yield ModelResponse(content=message_data.response_content_chunk)
-                            message_data.response_content += message_data.response_content_chunk
+                message_data.response_content_chunk = message_data.response_message.get("content", "")
+
+                yield ModelResponse(content=message_data.response_content_chunk)
+                message_data.response_content += message_data.response_content_chunk
+
+                message_data.tool_call_blocks = message_data.response_message.get("tool_calls")
+                if message_data.tool_call_blocks is not None:
+                    for block in message_data.tool_call_blocks:
+                        tool_call = block.get("function")
+                        tool_name = tool_call.get("name")
+                        tool_args = tool_call.get("arguments")
+
+                        function_def = {
+                            "name": tool_name,
+                            "arguments": json.dumps(tool_args) if tool_args is not None else None,
+                        }
+                        message_data.tool_calls.append({"type": "function", "function": function_def})

             if response.get("done"):
                 message_data.response_usage = response
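End to end, the change can be exercised with a tool-calling agent. A minimal sketch, assuming a local Ollama server and a tool-capable model tag (both assumptions, not part of this PR):

from phi.agent import Agent
from phi.model.ollama import Ollama

def get_weather(city: str) -> str:
    """Illustrative tool: returns canned weather for a city."""
    return f"Sunny in {city}"

# Assumes Ollama is running locally with a model that supports tool calls
agent = Agent(model=Ollama(id="llama3.1"), tools=[get_weather], show_tool_calls=True)

# Streams through the rewritten response_stream; tool calls now arrive via
# the structured "tool_calls" field instead of being parsed out of text
agent.print_response("What is the weather in Paris?", stream=True)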