
Commit

Python: graduate filters, add exception during addition and some cleanup (#9856)

### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
This PR graduates the filters out of the experimental stage, removing the `@experimental_class` and `@experimental_function` decorators and exposing the filter contexts from `semantic_kernel.filters`.

It also updates the docstrings and adds a dedicated exception, `FilterManagementException`, raised for errors that occur while adding or removing filters.

Closes #9838 
Fixes #9641 
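
As a rough sketch of how the graduated surface reads after this change (the filter body and names here are illustrative, not part of the PR):

```python
from semantic_kernel import Kernel
from semantic_kernel.exceptions import FilterManagementException
from semantic_kernel.filters import FilterTypes, FunctionInvocationContext

kernel = Kernel()


async def noop_filter(context: FunctionInvocationContext, next):
    # Illustrative pass-through filter: call the next filter (or the function itself).
    await next(context)


# Filter types and contexts are now importable from semantic_kernel.filters.
kernel.add_filter(FilterTypes.FUNCTION_INVOCATION, noop_filter)

# Add/remove errors now surface as FilterManagementException instead of ValueError.
try:
    kernel.remove_filter()  # neither filter_id nor position supplied
except FilterManagementException as exc:
    print(f"Filter management failed: {exc}")
```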

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
eavanvalkenburg authored Dec 3, 2024
1 parent 560e4c9 commit e0042af
Showing 24 changed files with 179 additions and 90 deletions.
1 change: 1 addition & 0 deletions python/pyproject.toml
@@ -146,6 +146,7 @@ environments = [
]

[tool.pytest.ini_options]
+testpaths = 'tests'
addopts = "-ra -q -r fEX"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = [
15 changes: 4 additions & 11 deletions python/samples/concepts/filtering/auto_function_invoke_filters.py
@@ -4,19 +4,12 @@
import os

from semantic_kernel import Kernel
-from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
+from semantic_kernel.connectors.ai import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAIChatPromptExecutionSettings
-from semantic_kernel.contents import ChatHistory
-from semantic_kernel.contents.chat_message_content import ChatMessageContent
-from semantic_kernel.contents.function_call_content import FunctionCallContent
-from semantic_kernel.contents.function_result_content import FunctionResultContent
+from semantic_kernel.contents import ChatHistory, ChatMessageContent, FunctionCallContent, FunctionResultContent
from semantic_kernel.core_plugins import MathPlugin, TimePlugin
-from semantic_kernel.filters.auto_function_invocation.auto_function_invocation_context import (
-    AutoFunctionInvocationContext,
-)
-from semantic_kernel.filters.filter_types import FilterTypes
-from semantic_kernel.functions import KernelArguments
-from semantic_kernel.functions.function_result import FunctionResult
+from semantic_kernel.filters import AutoFunctionInvocationContext, FilterTypes
+from semantic_kernel.functions import FunctionResult, KernelArguments

system_message = """
You are a chat bot. Your name is Mosscap and
11 changes: 5 additions & 6 deletions python/samples/concepts/filtering/function_invocation_filters.py
@@ -6,12 +6,11 @@
from collections.abc import Callable, Coroutine
from typing import Any

-from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion
-from semantic_kernel.contents.chat_history import ChatHistory
-from semantic_kernel.exceptions.kernel_exceptions import OperationCancelledException
-from semantic_kernel.filters.filter_types import FilterTypes
-from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext
-from semantic_kernel.kernel import Kernel
+from semantic_kernel import Kernel
+from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
+from semantic_kernel.contents import ChatHistory
+from semantic_kernel.exceptions import OperationCancelledException
+from semantic_kernel.filters import FilterTypes, FunctionInvocationContext

logger = logging.getLogger(__name__)

@@ -3,15 +3,15 @@
import asyncio
import logging
import os
+from collections.abc import Callable, Coroutine
from functools import reduce
+from typing import Any

-from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
-from semantic_kernel.contents import AuthorRole
-from semantic_kernel.contents.chat_history import ChatHistory
-from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
-from semantic_kernel.filters.filter_types import FilterTypes
-from semantic_kernel.functions.function_result import FunctionResult
-from semantic_kernel.kernel import Kernel
+from semantic_kernel import Kernel
+from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion
+from semantic_kernel.contents import AuthorRole, ChatHistory, StreamingChatMessageContent
+from semantic_kernel.filters import FilterTypes, FunctionInvocationContext
+from semantic_kernel.functions import FunctionResult

logger = logging.getLogger(__name__)

@@ -32,15 +32,20 @@
# in the specific case of a filter for streaming functions, you need to override the generator
# that is present in the function_result.value as seen below.
@kernel.filter(FilterTypes.FUNCTION_INVOCATION)
-async def streaming_exception_handling(context, next):
+async def streaming_exception_handling(
+    context: FunctionInvocationContext,
+    next: Callable[[FunctionInvocationContext], Coroutine[Any, Any, None]],
+):
    await next(context)

    async def override_stream(stream):
        try:
            async for partial in stream:
                yield partial
        except Exception as e:
-            yield [StreamingChatMessageContent(role=AuthorRole.ASSISTANT, content=f"Exception caught: {e}")]
+            yield [
+                StreamingChatMessageContent(role=AuthorRole.ASSISTANT, content=f"Exception caught: {e}", choice_index=0)
+            ]

    stream = context.result.value
    context.result = FunctionResult(function=context.result.function, value=override_stream(stream))
3 changes: 2 additions & 1 deletion python/semantic_kernel/connectors/ai/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

+from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings

__all__ = ["PromptExecutionSettings"]
__all__ = ["FunctionChoiceBehavior", "PromptExecutionSettings"]
2 changes: 1 addition & 1 deletion python/semantic_kernel/contents/chat_message_content.py
@@ -84,7 +84,7 @@ class ChatMessageContent(KernelContent):
    tag: ClassVar[str] = CHAT_MESSAGE_CONTENT_TAG
    role: AuthorRole
    name: str | None = None
-    items: list[Annotated[ITEM_TYPES, Field(..., discriminator=DISCRIMINATOR_FIELD)]] = Field(default_factory=list)
+    items: list[Annotated[ITEM_TYPES, Field(discriminator=DISCRIMINATOR_FIELD)]] = Field(default_factory=list)
    encoding: str | None = None
    finish_reason: FinishReason | None = None

1 change: 1 addition & 0 deletions python/semantic_kernel/exceptions/__init__.py
@@ -2,6 +2,7 @@

from semantic_kernel.exceptions.agent_exceptions import * # noqa: F403
from semantic_kernel.exceptions.content_exceptions import * # noqa: F403
+from semantic_kernel.exceptions.filter_exceptions import * # noqa: F403
from semantic_kernel.exceptions.function_exceptions import * # noqa: F403
from semantic_kernel.exceptions.kernel_exceptions import * # noqa: F403
from semantic_kernel.exceptions.memory_connector_exceptions import * # noqa: F403
20 changes: 20 additions & 0 deletions python/semantic_kernel/exceptions/filter_exceptions.py
@@ -0,0 +1,20 @@
+# Copyright (c) Microsoft. All rights reserved.
+from semantic_kernel.exceptions.kernel_exceptions import KernelException
+
+
+class FilterException(KernelException):
+    """Base class for all filter exceptions."""
+
+    pass
+
+
+class FilterManagementException(FilterException):
+    """An error occurred while adding or removing the filter to/from the kernel."""
+
+    pass
+
+
+__all__ = [
+    "FilterException",
+    "FilterManagementException",
+]
15 changes: 15 additions & 0 deletions python/semantic_kernel/filters/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+from semantic_kernel.filters.auto_function_invocation.auto_function_invocation_context import (
+    AutoFunctionInvocationContext,
+)
+from semantic_kernel.filters.filter_types import FilterTypes
+from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext
+from semantic_kernel.filters.prompts.prompt_render_context import PromptRenderContext
+
+__all__ = [
+    "AutoFunctionInvocationContext",
+    "FilterTypes",
+    "FunctionInvocationContext",
+    "PromptRenderContext",
+]
@@ -10,7 +10,27 @@


class AutoFunctionInvocationContext(FilterContextBase):
"""Class for auto function invocation context."""
"""Class for auto function invocation context.
This is the context supplied to the auto function invocation filters.
Common use case are to alter the function_result, for instance filling it with a pre-computed
value, in order to skip a step, for instance when doing caching.
Another option is to terminate, this can be done by setting terminate to True.
Attributes:
function: The function invoked.
kernel: The kernel used.
arguments: The arguments used to call the function.
chat_history: The chat history or None.
function_result: The function result or None.
request_sequence_index: The request sequence index.
function_sequence_index: The function sequence index.
function_count: The function count.
terminate: The flag to terminate.
"""

chat_history: "ChatHistory | None" = None
function_result: "FunctionResult | None" = None
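
For context, a minimal sketch of the caching/termination use case this docstring describes (the cache object and names are hypothetical, not part of this commit):

```python
from semantic_kernel import Kernel
from semantic_kernel.filters import AutoFunctionInvocationContext, FilterTypes
from semantic_kernel.functions import FunctionResult

kernel = Kernel()
response_cache: dict[str, str] = {}  # hypothetical store of pre-computed results


@kernel.filter(FilterTypes.AUTO_FUNCTION_INVOCATION)
async def cached_auto_invoke(context: AutoFunctionInvocationContext, next):
    key = context.function.fully_qualified_name
    if key in response_cache:
        # Pre-fill the result and stop the auto function invocation loop.
        context.function_result = FunctionResult(
            function=context.function.metadata, value=response_cache[key]
        )
        context.terminate = True
        return
    await next(context)
```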
2 changes: 0 additions & 2 deletions python/semantic_kernel/filters/filter_context_base.py
@@ -3,15 +3,13 @@
from typing import TYPE_CHECKING

from semantic_kernel.kernel_pydantic import KernelBaseModel
-from semantic_kernel.utils.experimental_decorator import experimental_class

if TYPE_CHECKING:
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.functions.kernel_function import KernelFunction
from semantic_kernel.kernel import Kernel


-@experimental_class
class FilterContextBase(KernelBaseModel):
"""Base class for Kernel Filter Contexts."""

3 changes: 0 additions & 3 deletions python/semantic_kernel/filters/filter_types.py
@@ -2,10 +2,7 @@

from enum import Enum

-from semantic_kernel.utils.experimental_decorator import experimental_class


-@experimental_class
class FilterTypes(str, Enum):
"""Enum for the filter types."""

@@ -9,6 +9,18 @@


class FunctionInvocationContext(FilterContextBase):
"""Class for function invocation context."""
"""Class for function invocation context.
This filter can be used to monitor which functions are called.
To log what function was called with which parameters and what output.
Finally it can be used for caching by setting the result value.
Attributes:
function: The function invoked.
kernel: The kernel used.
arguments: The arguments used to call the function.
result: The result of the function, or None.
"""

result: "FunctionResult | None" = None
21 changes: 12 additions & 9 deletions python/semantic_kernel/filters/kernel_filters_extension.py
@@ -7,10 +7,10 @@

from pydantic import Field

+from semantic_kernel.exceptions.filter_exceptions import FilterManagementException
from semantic_kernel.filters.filter_context_base import FilterContextBase
from semantic_kernel.filters.filter_types import FilterTypes
from semantic_kernel.kernel_pydantic import KernelBaseModel
-from semantic_kernel.utils.experimental_decorator import experimental_function

FILTER_CONTEXT_TYPE = TypeVar("FILTER_CONTEXT_TYPE", bound=FilterContextBase)
CALLABLE_FILTER_TYPE = Callable[[FILTER_CONTEXT_TYPE, Callable[[FILTER_CONTEXT_TYPE], None]], None]
@@ -32,7 +32,6 @@ class KernelFilterExtension(KernelBaseModel, ABC):
    prompt_rendering_filters: list[tuple[int, CALLABLE_FILTER_TYPE]] = Field(default_factory=list)
    auto_function_invocation_filters: list[tuple[int, CALLABLE_FILTER_TYPE]] = Field(default_factory=list)

-    @experimental_function
    def add_filter(self, filter_type: ALLOWED_FILTERS_LITERAL | FilterTypes, filter: CALLABLE_FILTER_TYPE) -> None:
        """Add a filter to the Kernel.
@@ -45,12 +44,17 @@ def add_filter(self, filter_type: ALLOWED_FILTERS_LITERAL | FilterTypes, filter:
            filter_type (str): The type of the filter to add (function_invocation, prompt_rendering)
            filter (object): The filter to add

+        Raises:
+            FilterDefinitionException: If an error occurs while adding the filter to the kernel
        """
-        if not isinstance(filter_type, FilterTypes):
-            filter_type = FilterTypes(filter_type)
-        getattr(self, FILTER_MAPPING[filter_type.value]).insert(0, (id(filter), filter))
+        try:
+            if not isinstance(filter_type, FilterTypes):
+                filter_type = FilterTypes(filter_type)
+            getattr(self, FILTER_MAPPING[filter_type.value]).insert(0, (id(filter), filter))
+        except Exception as ecx:
+            raise FilterManagementException(f"Error adding filter {filter} to {filter_type}") from ecx

-    @experimental_function
    def filter(
        self, filter_type: ALLOWED_FILTERS_LITERAL | FilterTypes
    ) -> Callable[[CALLABLE_FILTER_TYPE], CALLABLE_FILTER_TYPE]:
@@ -64,7 +68,6 @@ def decorator(

        return decorator

-    @experimental_function
    def remove_filter(
        self,
        filter_type: ALLOWED_FILTERS_LITERAL | FilterTypes | None = None,
@@ -83,10 +86,10 @@ def remove_filter(
        if filter_type and not isinstance(filter_type, FilterTypes):
            filter_type = FilterTypes(filter_type)
        if filter_id is None and position is None:
-            raise ValueError("Either hook_id or position should be provided.")
+            raise FilterManagementException("Either hook_id or position should be provided.")
        if position is not None:
            if filter_type is None:
-                raise ValueError("Please specify the type of filter when using position.")
+                raise FilterManagementException("Please specify the type of filter when using position.")
            getattr(self, FILTER_MAPPING[filter_type]).pop(position)
            return
        if filter_type:
14 changes: 13 additions & 1 deletion python/semantic_kernel/filters/prompts/prompt_render_context.py
@@ -9,7 +9,19 @@


class PromptRenderContext(FilterContextBase):
"""Context for prompt rendering filters."""
"""Context for prompt rendering filters.
When prompt rendering is expensive (for instance when there are expensive functions being called.)
This filter can be used to set the rendered_prompt directly and returning.
Attributes:
function: The function invoked.
kernel: The kernel used.
arguments: The arguments used to call the function.
rendered_prompt: The result of the prompt rendering.
function_result: The result of the function that used the prompt.
"""

    rendered_prompt: str | None = None
    function_result: "FunctionResult | None" = None
@@ -11,7 +11,7 @@
# There is only the whisper model available on Azure OpenAI for audio to text. And that model is
# only available in the North Switzerland region. Therefore, the endpoint is different than the one
# we use for other services.
-is_service_setup_for_testing(["AZURE_OPENAI_AUDIO_TO_TEXT_ENDPOINT"])
+azure_setup = is_service_setup_for_testing(["AZURE_OPENAI_AUDIO_TO_TEXT_ENDPOINT"], raise_if_not_set=False)


class AudioToTextTestBase:
@@ -22,5 +22,7 @@ def services(self) -> dict[str, AudioToTextClientBase]:
"""Return audio-to-text services."""
return {
"openai": OpenAIAudioToText(),
"azure_openai": AzureAudioToText(endpoint=os.environ["AZURE_OPENAI_AUDIO_TO_TEXT_ENDPOINT"]),
"azure_openai": AzureAudioToText(endpoint=os.environ["AZURE_OPENAI_AUDIO_TO_TEXT_ENDPOINT"])
if azure_setup
else None,
}
2 changes: 2 additions & 0 deletions python/tests/integration/audio_to_text/test_audio_to_text.py
@@ -49,6 +49,8 @@ async def test_audio_to_text(
"""

service = services[service_id]
if not service:
pytest.mark.xfail("Azure Audio to Text not setup.")
result = await service.get_text_content(audio_content)

for word in expected_text:
26 changes: 15 additions & 11 deletions python/tests/integration/completions/chat_completion_test_base.py
@@ -56,17 +56,21 @@
# There is no single model in Ollama that supports both image and tool call in chat completion
# We are splitting the Ollama test into three services: chat, image, and tool call. The chat model
# can be any model that supports chat completion. Also, Ollama is only available on Linux runners in our pipeline.
-ollama_setup: bool = is_service_setup_for_testing(["OLLAMA_CHAT_MODEL_ID"]) and is_test_running_on_supported_platforms([
-    "Linux"
-])
-ollama_image_setup: bool = is_service_setup_for_testing([
-    "OLLAMA_CHAT_MODEL_ID_IMAGE"
-]) and is_test_running_on_supported_platforms(["Linux"])
-ollama_tool_call_setup: bool = is_service_setup_for_testing([
-    "OLLAMA_CHAT_MODEL_ID_TOOL_CALL"
-]) and is_test_running_on_supported_platforms(["Linux"])
-google_ai_setup: bool = is_service_setup_for_testing(["GOOGLE_AI_API_KEY", "GOOGLE_AI_GEMINI_MODEL_ID"])
-vertex_ai_setup: bool = is_service_setup_for_testing(["VERTEX_AI_PROJECT_ID", "VERTEX_AI_GEMINI_MODEL_ID"])
+ollama_setup: bool = is_service_setup_for_testing(
+    ["OLLAMA_CHAT_MODEL_ID"], raise_if_not_set=False
+) and is_test_running_on_supported_platforms(["Linux"])
+ollama_image_setup: bool = is_service_setup_for_testing(
+    ["OLLAMA_CHAT_MODEL_ID_IMAGE"], raise_if_not_set=False
+) and is_test_running_on_supported_platforms(["Linux"])
+ollama_tool_call_setup: bool = is_service_setup_for_testing(
+    ["OLLAMA_CHAT_MODEL_ID_TOOL_CALL"], raise_if_not_set=False
+) and is_test_running_on_supported_platforms(["Linux"])
+google_ai_setup: bool = is_service_setup_for_testing(
+    ["GOOGLE_AI_API_KEY", "GOOGLE_AI_GEMINI_MODEL_ID"], raise_if_not_set=False
+)
+vertex_ai_setup: bool = is_service_setup_for_testing(
+    ["VERTEX_AI_PROJECT_ID", "VERTEX_AI_GEMINI_MODEL_ID"], raise_if_not_set=False
+)
onnx_setup: bool = is_service_setup_for_testing(
["ONNX_GEN_AI_CHAT_MODEL_FOLDER"], raise_if_not_set=False
) # Tests are optional for ONNX