Feat: IBM watsonx.ai Chat integration (#16589)
Wojciech-Rebisz authored Oct 18, 2024
1 parent bdd4dd2 commit f33ce66
Showing 4 changed files with 331 additions and 19 deletions.
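For orientation, here is a minimal usage sketch of the chat interface this commit enables. The import path follows the package seen in the diff (llama_index.llms.ibm); the constructor arguments, model id, and endpoint below are illustrative assumptions, not taken from the commit:

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.llms.ibm import WatsonxLLM  # package path taken from the diff's imports

# Hypothetical setup -- model id, endpoint, and credential kwargs are assumptions.
llm = WatsonxLLM(
    model_id="ibm/granite-13b-chat-v2",
    url="https://us-south.ml.cloud.ibm.com",
    apikey="<YOUR_API_KEY>",
    project_id="<YOUR_PROJECT_ID>",
)

messages = [
    ChatMessage(role=MessageRole.SYSTEM, content="You are a helpful assistant."),
    ChatMessage(role=MessageRole.USER, content="What does watsonx.ai provide?"),
]

# Per the diff, chat() routes to the native chat endpoint by default; passing
# use_completions=True falls back to the older prompt-completion path.
response = llm.chat(messages)
print(response.message.content)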
@@ -1,16 +1,19 @@
-from typing import Any, Dict, Optional, Sequence, Union, Tuple
+from typing import Any, Dict, Optional, Sequence, Union, Tuple, List

 from ibm_watsonx_ai import Credentials, APIClient
 from ibm_watsonx_ai.foundation_models import ModelInference
-from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
+from ibm_watsonx_ai.metanames import GenTextParamsMetaNames, GenChatParamsMetaNames

 from llama_index.core.base.llms.types import (
     ChatMessage,
     ChatResponse,
+    ChatResponseAsyncGen,
     ChatResponseGen,
     CompletionResponse,
+    CompletionResponseAsyncGen,
     CompletionResponseGen,
     LLMMetadata,
+    MessageRole,
 )
 from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
 from llama_index.core.bridge.pydantic import (
@@ -19,26 +22,35 @@
     PrivateAttr,
 )

+from llama_index.core.llms.llm import ToolSelection
+
 # Import SecretStr directly from pydantic
 # since there is not one in llama_index.core.bridge.pydantic
 from pydantic import SecretStr

+from llama_index.core.llms.function_calling import FunctionCallingLLM
+
 from llama_index.core.callbacks import CallbackManager
 from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
 from llama_index.core.base.llms.generic_utils import (
     completion_to_chat_decorator,
     stream_completion_to_chat_decorator,
 )
 from llama_index.core.llms.custom import CustomLLM

+from llama_index.core.llms.utils import parse_partial_json
+
 from llama_index.llms.ibm.utils import (
     resolve_watsonx_credentials,
+    to_watsonx_message_dict,
+    from_watsonx_message,
+    update_tool_calls,
 )

 # default max tokens determined by service
 DEFAULT_MAX_TOKENS = 20


-class WatsonxLLM(CustomLLM):
+class WatsonxLLM(FunctionCallingLLM):
     """
     IBM watsonx.ai large language models.
@@ -305,6 +317,11 @@ def sample_generation_text_params(self) -> Dict[str, Any]:
"""Example of Model generation text kwargs that a user can pass to the model."""
return GenTextParamsMetaNames().get_example_values()

@property
def sample_chat_generation_params(self) -> Dict[str, Any]:
"""Example of Model chat generation kwargs that a user can pass to the model."""
return GenChatParamsMetaNames().get_example_values()

def _split_generation_params(
self, data: Dict[str, Any]
) -> Tuple[Dict[str, Any] | None, Dict[str, Any]]:
@@ -319,6 +336,19 @@ def _split_generation_params(
                 kwargs.update({key: value})
         return params if params else None, kwargs

+    def _split_chat_generation_params(
+        self, data: Dict[str, Any]
+    ) -> Tuple[Dict[str, Any] | None, Dict[str, Any]]:
+        params = {}
+        kwargs = {}
+        sample_generation_kwargs_keys = set(self.sample_chat_generation_params.keys())
+        for key, value in data.items():
+            if key in sample_generation_kwargs_keys:
+                params.update({key: value})
+            else:
+                kwargs.update({key: value})
+        return params if params else None, kwargs
+
     @llm_completion_callback()
     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
@@ -335,6 +365,12 @@ def complete(
             raw=response,
         )

+    @llm_completion_callback()
+    async def acomplete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
+        return self.complete(prompt, formatted=formatted, **kwargs)
+
     @llm_completion_callback()
     def stream_complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
@@ -365,16 +401,214 @@ def gen() -> CompletionResponseGen:

         return gen()

+    @llm_completion_callback()
+    async def astream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseAsyncGen:
+        async def gen() -> CompletionResponseAsyncGen:
+            for message in self.stream_complete(prompt, formatted=formatted, **kwargs):
+                yield message
+
+        # NOTE: convert generator to async generator
+        return gen()
+
+    def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        message_dicts = [to_watsonx_message_dict(message) for message in messages]
+
+        params, generation_kwargs = self._split_chat_generation_params(kwargs)
+        response = self._model.chat(
+            messages=message_dicts,
+            params=params,
+            tools=generation_kwargs.get("tools"),
+            tool_choice=generation_kwargs.get("tool_choice"),
+            tool_choice_option=generation_kwargs.get("tool_choice_option"),
+        )
+
+        wx_message = response["choices"][0]["message"]
+        message = from_watsonx_message(wx_message)
+
+        return ChatResponse(
+            message=message,
+            raw=response,
+        )
+
     @llm_chat_callback()
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
-        chat_fn = completion_to_chat_decorator(self.complete)
+        if kwargs.get("use_completions"):
+            chat_fn = completion_to_chat_decorator(self.complete)
+        else:
+            chat_fn = self._chat
+
         return chat_fn(messages, **kwargs)

+    @llm_chat_callback()
+    async def achat(
+        self,
+        messages: Sequence[ChatMessage],
+        **kwargs: Any,
+    ) -> ChatResponse:
+        return self.chat(messages, **kwargs)
+
+    def _stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        message_dicts = [to_watsonx_message_dict(message) for message in messages]
+
+        params, generation_kwargs = self._split_chat_generation_params(kwargs)
+        stream_response = self._model.chat_stream(
+            messages=message_dicts,
+            params=params,
+            tools=generation_kwargs.get("tools"),
+            tool_choice=generation_kwargs.get("tool_choice"),
+            tool_choice_option=generation_kwargs.get("tool_choice_option"),
+        )
+
+        def stream_gen() -> ChatResponseGen:
+            content = ""
+            role = None
+            tool_calls = []
+
+            for response in stream_response:
+                tools_available = False
+                wx_message = response["choices"][0]["delta"]
+
+                role = wx_message.get("role") or role or MessageRole.ASSISTANT
+                delta = wx_message.get("content", "")
+                content += delta
+
+                if "tool_calls" in wx_message:
+                    tools_available = True
+
+                additional_kwargs = {}
+                if tools_available:
+                    tool_calls = update_tool_calls(tool_calls, wx_message["tool_calls"])
+                    if tool_calls:
+                        additional_kwargs["tool_calls"] = tool_calls
+
+                yield ChatResponse(
+                    message=ChatMessage(
+                        role=role,
+                        content=content,
+                        additional_kwargs=additional_kwargs,
+                    ),
+                    delta=delta,
+                    raw=response,
+                    additional_kwargs=self._get_response_token_counts(response),
+                )
+
+        return stream_gen()
+
     @llm_chat_callback()
     def stream_chat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
-        chat_stream_fn = stream_completion_to_chat_decorator(self.stream_complete)
+        if kwargs.get("use_completions"):
+            chat_stream_fn = stream_completion_to_chat_decorator(self.stream_complete)
+        else:
+            chat_stream_fn = self._stream_chat
+
         return chat_stream_fn(messages, **kwargs)

+    @llm_chat_callback()
+    async def astream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseAsyncGen:
+        async def gen() -> ChatResponseAsyncGen:
+            for message in self.stream_chat(messages, **kwargs):
+                yield message
+
+        # NOTE: convert generator to async generator
+        return gen()
+
+    def _prepare_chat_with_tools(
+        self,
+        tools: List["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        tool_choice: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
"""Predict and call the tool."""
+        # watsonx uses the same openai tool format
+        tool_specs = [tool.metadata.to_openai_tool() for tool in tools]
+
+        if isinstance(user_msg, str):
+            user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
+
+        messages = chat_history or []
+        if user_msg:
+            messages.append(user_msg)
+
+        chat_with_tools_payload = {
+            "messages": messages,
+            "tools": tool_specs or None,
+            **kwargs,
+        }
+        if tool_choice is not None:
+            chat_with_tools_payload.update(
+                {"tool_choice": {"type": "function", "function": {"name": tool_choice}}}
+            )
+        return chat_with_tools_payload
+
+    def get_tool_calls_from_response(
+        self,
+        response: ChatResponse,
+        error_on_no_tool_call: bool = True,
+        **kwargs: Any,
+    ) -> List[ToolSelection]:
"""Predict and call the tool."""
+        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+
+        if len(tool_calls) < 1:
+            if error_on_no_tool_call:
+                raise ValueError(
+                    f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+                )
+            else:
+                return []
+
+        tool_selections = []
+        for tool_call in tool_calls:
+            if not isinstance(tool_call, dict):
+                raise ValueError("Invalid tool_call object")
+            if tool_call.get("type") != "function":
+                raise ValueError("Invalid tool type. Unsupported by watsonx.ai")
+
+            # this should handle both complete and partial jsons
+            try:
+                argument_dict = parse_partial_json(
+                    tool_call.get("function", {}).get("arguments")
+                )
+            except ValueError:
+                argument_dict = {}
+
+            tool_selections.append(
+                ToolSelection(
+                    tool_id=tool_call.get("id"),
+                    tool_name=tool_call.get("function").get("name"),
+                    tool_kwargs=argument_dict,
+                )
+            )
+
+        return tool_selections
+
+    def _get_response_token_counts(self, raw_response: Any) -> dict:
+        """Get the token usage reported by the response."""
+        if isinstance(raw_response, dict):
+            usage = raw_response.get("usage", {})
+            if usage is None:
+                return {}
+
+            prompt_tokens = usage.get("prompt_tokens", 0)
+            completion_tokens = usage.get("completion_tokens", 0)
+            total_tokens = usage.get("total_tokens", 0)
+        else:
+            return {}
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+        }
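Because WatsonxLLM now subclasses FunctionCallingLLM, the generic llama_index tool-calling flow should apply to it. A hedged sketch follows; the tool and model configuration are illustrative assumptions, and chat_with_tools is the standard FunctionCallingLLM entry point rather than something defined in this diff:

from llama_index.core.tools import FunctionTool
from llama_index.llms.ibm import WatsonxLLM


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


tool = FunctionTool.from_defaults(fn=multiply)

# Hypothetical setup, as in the earlier sketch.
llm = WatsonxLLM(
    model_id="mistralai/mistral-large",  # illustrative tool-capable model
    url="https://us-south.ml.cloud.ibm.com",
    apikey="<YOUR_API_KEY>",
    project_id="<YOUR_PROJECT_ID>",
)

# _prepare_chat_with_tools (above) converts each tool to the OpenAI schema;
# get_tool_calls_from_response extracts ToolSelection objects, tolerating
# partially streamed JSON arguments via parse_partial_json.
response = llm.chat_with_tools([tool], user_msg="What is 3 * 12?")
for selection in llm.get_tool_calls_from_response(response, error_on_no_tool_call=False):
    print(selection.tool_name, selection.tool_kwargs)

The second diff below covers the package's utils module, which supplies the message-conversion and tool-call helpers imported above.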
@@ -1,11 +1,12 @@
 import os
 import urllib.parse
-from typing import Dict, Union, Optional
+from typing import Dict, Union, Optional, List, Any


 from llama_index.core.base.llms.generic_utils import (
     get_from_param_or_env,
 )
+from llama_index.core.base.llms.types import ChatMessage, MessageRole

 # Import SecretStr directly from pydantic
 # since there is not one in llama_index.core.bridge.pydantic
@@ -109,3 +110,51 @@ def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
     if isinstance(value, SecretStr):
         return value
     return SecretStr(value)
+
+
+def to_watsonx_message_dict(message: ChatMessage) -> dict:
+    """Convert generic message to message dict."""
+    message_dict = {
+        "role": message.role.value,
+        "content": message.content,
+    }
+
+    message_dict.update(message.additional_kwargs)
+
+    return message_dict
+
+
+def from_watsonx_message(message: dict) -> ChatMessage:
+    """Convert Watsonx message dict to generic message."""
+    role = message.get("role", MessageRole.ASSISTANT)
+    content = message.get("content")
+
+    additional_kwargs: Dict[str, Any] = {}
+    if message.get("tool_calls") is not None:
+        tool_calls: List[dict] = message.get("tool_calls")
+        additional_kwargs.update(tool_calls=tool_calls)
+
+    return ChatMessage(role=role, content=content, additional_kwargs=additional_kwargs)
+
+
+def update_tool_calls(tool_calls: list, tool_calls_update: list):
+    """Use the tool_calls_update objects received from stream chunks
+    to update the running tool_calls object.
+    """
+    if tool_calls_update is None:
+        return tool_calls
+
+    tc_delta = tool_calls_update[0]
+
+    if len(tool_calls) == 0:
+        tool_calls.append(tc_delta)
+    else:
+        t = tool_calls[-1]
+        if t["index"] != tc_delta["index"]:
+            tool_calls.append(tc_delta)
+        else:
+            t["function"]["arguments"] += tc_delta["function"]["arguments"] or ""
+            t["function"]["name"] += tc_delta["function"]["name"] or ""
+            t["id"] += tc_delta.get("id", "")
+
+    return tool_calls
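To make the accumulation logic concrete, here is a small walk-through of update_tool_calls with fabricated stream deltas shaped like the chunks _stream_chat passes in; the import path mirrors the diff's own llama_index.llms.ibm.utils imports:

from llama_index.llms.ibm.utils import update_tool_calls

running: list = []
deltas = [
    # first chunk opens the tool call and starts the JSON arguments
    [{"index": 0, "id": "call_1", "type": "function",
      "function": {"name": "multiply", "arguments": '{"a": 3'}}],
    # later chunks for the same index append name/argument fragments
    [{"index": 0, "id": "", "type": "function",
      "function": {"name": "", "arguments": ', "b": 12}'}}],
]

for chunk in deltas:
    running = update_tool_calls(running, chunk)

print(running[0]["function"]["arguments"])  # -> {"a": 3, "b": 12}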