From aba9fcd3475d1876f9c87b016deb7a75e8e5ddb8 Mon Sep 17 00:00:00 2001 From: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com> Date: Fri, 29 Nov 2024 19:38:26 +0000 Subject: [PATCH] Add missing execution settings --- .../Client/MistralClientTests.cs | 62 ++++++++++++- .../Client/ChatCompletionRequest.cs | 16 ++++ .../Client/MistralChatMessage.cs | 5 ++ .../Client/MistralClient.cs | 10 ++- .../MistralAIPromptExecutionSettings.cs | 89 +++++++++++++++++++ .../MistralAIChatCompletionTests.cs | 34 +++++++ 6 files changed, 212 insertions(+), 4 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs index d61696cec36d..fc00db9a8ea1 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs @@ -80,7 +80,7 @@ public async Task ValidateChatMessageRequestAsync() Assert.Equal(0.9, chatRequest.Temperature); Assert.Single(chatRequest.Messages); Assert.Equal("user", chatRequest.Messages[0].Role); - Assert.Equal("What is the best French cheese?", chatRequest.Messages[0].Content); + Assert.Equal("What is the best French cheese?", chatRequest.Messages[0].Content?.ToString()); } [Fact] @@ -522,6 +522,31 @@ public void ValidateToMistralChatMessages(string roleLabel, string content) Assert.Single(messages); } + [Fact] + public void ValidateToMistralChatMessagesWithMultipleContents() + { + // Arrange + using var httpClient = new HttpClient(); + var client = new MistralClient("mistral-large-latest", httpClient, "key"); + var chatMessage = new ChatMessageContent() + { + Role = AuthorRole.User, + Items = + [ + new TextContent("What is the weather like in Paris?"), + new ImageContent(new Uri("https://tripfixers.com/wp-content/uploads/2019/11/eiffel-tower-with-snow.jpeg")) + ], + }; + + // Act + var messages = client.ToMistralChatMessages(chatMessage, default); + + // Assert + Assert.NotNull(messages); + Assert.Single(messages); + Assert.IsType>(messages[0].Content); + } + [Fact] public void ValidateToMistralChatMessagesWithFunctionCallContent() { @@ -562,6 +587,41 @@ public void ValidateToMistralChatMessagesWithFunctionResultContent() Assert.Equal(2, messages.Count); } + [Fact] + public void ValidateCloneMistralAIPromptExecutionSettings() + { + // Arrange + var settings = new MistralAIPromptExecutionSettings + { + MaxTokens = 1024, + Temperature = 0.9, + TopP = 0.9, + FrequencyPenalty = 0.9, + PresencePenalty = 0.9, + Stop = ["stop"], + SafePrompt = true, + RandomSeed = 123, + ResponseFormat = new { format = "json" }, + }; + + // Act + var clonedSettings = settings.Clone(); + + // Assert + Assert.NotNull(clonedSettings); + Assert.IsType(clonedSettings); + var clonedMistralAISettings = clonedSettings as MistralAIPromptExecutionSettings; + Assert.Equal(settings.MaxTokens, clonedMistralAISettings!.MaxTokens); + Assert.Equal(settings.Temperature, clonedMistralAISettings.Temperature); + Assert.Equal(settings.TopP, clonedMistralAISettings.TopP); + Assert.Equal(settings.FrequencyPenalty, clonedMistralAISettings.FrequencyPenalty); + Assert.Equal(settings.PresencePenalty, clonedMistralAISettings.PresencePenalty); + Assert.Equal(settings.Stop, clonedMistralAISettings.Stop); + Assert.Equal(settings.SafePrompt, clonedMistralAISettings.SafePrompt); + Assert.Equal(settings.RandomSeed, clonedMistralAISettings.RandomSeed); + Assert.Equal(settings.ResponseFormat, clonedMistralAISettings.ResponseFormat); + } + public sealed class WeatherPlugin { [KernelFunction] diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs index e1fc8dbfe996..cf5a3258ea9a 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/ChatCompletionRequest.cs @@ -44,6 +44,22 @@ internal sealed class ChatCompletionRequest [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public int? RandomSeed { get; set; } + [JsonPropertyName("response_format")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public object? ResponseFormat { get; set; } + + [JsonPropertyName("frequency_penalty")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public double? FrequencyPenalty { get; set; } + + [JsonPropertyName("presence_penalty")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public double? PresencePenalty { get; set; } + + [JsonPropertyName("stop")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public IList? Stop { get; set; } + /// /// Construct an instance of . /// diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatMessage.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatMessage.cs index 3a7d385dc697..e587ac8f5c95 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatMessage.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralChatMessage.cs @@ -38,6 +38,11 @@ internal sealed class MistralChatMessage [JsonConstructor] internal MistralChatMessage(string? role, object? content) { + if (role is not null and not "system" and not "user" and not "assistant" and not "tool") + { + throw new System.ArgumentException($"Role must be one of: system, user, assistant or tool. {role} is an invalid role.", nameof(role)); + } + this.Role = role; this.Content = content; } diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs index 1981bcc1da6f..7f73c8ad8f79 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs @@ -693,7 +693,11 @@ private ChatCompletionRequest CreateChatCompletionRequest(string modelId, bool s TopP = executionSettings.TopP, MaxTokens = executionSettings.MaxTokens, SafePrompt = executionSettings.SafePrompt, - RandomSeed = executionSettings.RandomSeed + RandomSeed = executionSettings.RandomSeed, + ResponseFormat = executionSettings.ResponseFormat, + FrequencyPenalty = executionSettings.FrequencyPenalty, + PresencePenalty = executionSettings.PresencePenalty, + Stop = executionSettings.Stop, }; executionSettings.ToolCallBehavior?.ConfigureRequest(kernel, request); @@ -1016,8 +1020,8 @@ private static string ProcessFunctionResult(object functionResult, MistralAITool return stringResult; } - // This is an optimization to use ChatMessageContent chatMessage directly - // without unnecessary serialization of the whole message chatMessage class. + // This is an optimization to use ChatMessageContent content directly + // without unnecessary serialization of the whole message content class. if (functionResult is ChatMessageContent chatMessageContent) { return chatMessageContent.ToString(); diff --git a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs index 9e136d0e089f..6766ceb317c0 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs @@ -155,6 +155,87 @@ public MistralAIToolCallBehavior? ToolCallBehavior } } + /// + /// Gets or sets the response format to use for the completion. + /// + /// + /// An object specifying the format that the model must output. + /// Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON. + /// When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message. + /// + [JsonPropertyName("response_format")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public object? ResponseFormat + { + get => this._responseFormat; + + set + { + this.ThrowIfFrozen(); + this._responseFormat = value; + } + } + + /// + /// Gets or sets the stop sequences to use for the completion. + /// + /// + /// Stop generation if this token is detected. Or if one of these tokens is detected when providing an array + /// + [JsonPropertyName("stop")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public IList? Stop + { + get => this._stop; + + set + { + this.ThrowIfFrozen(); + this._stop = value; + } + } + + + /// + /// Number between -2.0 and 2.0. Positive values penalize new tokens + /// based on whether they appear in the text so far, increasing the + /// model's likelihood to talk about new topics. + /// + /// + /// presence_penalty determines how much the model penalizes the repetition of words or phrases. + /// A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative. + /// + [JsonPropertyName("presence_penalty")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public double? PresencePenalty + { + get => this._presencePenalty; + + set + { + this.ThrowIfFrozen(); + this._presencePenalty = value; + } + } + + /// + /// Number between -2.0 and 2.0. Positive values penalize new tokens + /// based on their existing frequency in the text so far, decreasing + /// the model's likelihood to repeat the same line verbatim. + /// + [JsonPropertyName("frequency_penalty")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public double? FrequencyPenalty + { + get => this._frequencyPenalty; + + set + { + this.ThrowIfFrozen(); + this._frequencyPenalty = value; + } + } + /// public override void Freeze() { @@ -180,6 +261,10 @@ public override PromptExecutionSettings Clone() RandomSeed = this.RandomSeed, ApiVersion = this.ApiVersion, ToolCallBehavior = this.ToolCallBehavior, + ResponseFormat = this.ResponseFormat, + FrequencyPenalty = this.FrequencyPenalty, + PresencePenalty = this.PresencePenalty, + Stop = this.Stop, }; } @@ -215,6 +300,10 @@ public static MistralAIPromptExecutionSettings FromExecutionSettings(PromptExecu private int? _randomSeed; private string _apiVersion = "v1"; private MistralAIToolCallBehavior? _toolCallBehavior; + private object? _responseFormat; + private double? _presencePenalty; + private double? _frequencyPenalty; + private IList? _stop; #endregion } diff --git a/dotnet/src/IntegrationTests/Connectors/MistralAI/ChatCompletion/MistralAIChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/MistralAI/ChatCompletion/MistralAIChatCompletionTests.cs index 62d23f5a0517..d70dee87442d 100644 --- a/dotnet/src/IntegrationTests/Connectors/MistralAI/ChatCompletion/MistralAIChatCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/MistralAI/ChatCompletion/MistralAIChatCompletionTests.cs @@ -147,6 +147,40 @@ public async Task ValidateGetChatMessageContentsWithImageAsync() Assert.Contains("Snow", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase); } + [Fact] + public async Task ValidateGetChatMessageContentsWithImageAndJsonFormatAsync() + { + // Arrange + var model = this._configuration["MistralAI:ImageModelId"]; + var apiKey = this._configuration["MistralAI:ApiKey"]; + var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient); + + // Act + var systemMessage = "Return the answer in a JSON object with the next structure: " + + "{\"elements\": [{\"element\": \"some name of element1\", " + + "\"description\": \"some description of element 1\"}, " + + "{\"element\": \"some name of element2\", \"description\": " + + "\"some description of element 2\"}]}"; + var chatHistory = new ChatHistory(systemMessage) + { + new ChatMessageContent(AuthorRole.User, "Describe the image"), + new ChatMessageContent(AuthorRole.User, [new ImageContent(new Uri("https://tripfixers.com/wp-content/uploads/2019/11/eiffel-tower-with-snow.jpeg"))]) + }; + var executionSettings = new MistralAIPromptExecutionSettings + { + MaxTokens = 500, + ResponseFormat = new { type = "json_object" }, + }; + var response = await service.GetChatMessageContentsAsync(chatHistory, executionSettings); + + // Assert + Assert.NotNull(response); + Assert.Single(response); + Assert.Contains("Paris", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Eiffel Tower", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Snow", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase); + } + [Fact(Skip = "This test is for manual verification.")] public async Task ValidateInvokeChatPromptAsync() {