Skip to content

Commit

Permalink
Add missing execution settings
Browse files Browse the repository at this point in the history
  • Loading branch information
markwallace-microsoft committed Nov 29, 2024
1 parent 34a2aa7 commit aba9fcd
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public async Task ValidateChatMessageRequestAsync()
Assert.Equal(0.9, chatRequest.Temperature);
Assert.Single(chatRequest.Messages);
Assert.Equal("user", chatRequest.Messages[0].Role);
Assert.Equal("What is the best French cheese?", chatRequest.Messages[0].Content);
Assert.Equal("What is the best French cheese?", chatRequest.Messages[0].Content?.ToString());
}

[Fact]
Expand Down Expand Up @@ -522,6 +522,31 @@ public void ValidateToMistralChatMessages(string roleLabel, string content)
Assert.Single(messages);
}

[Fact]
public void ValidateToMistralChatMessagesWithMultipleContents()
{
    // Arrange: a single user message carrying both a text item and an image item.
    using var httpClient = new HttpClient();
    var client = new MistralClient("mistral-large-latest", httpClient, "key");
    var message = new ChatMessageContent
    {
        Role = AuthorRole.User,
        Items =
        [
            new TextContent("What is the weather like in Paris?"),
            new ImageContent(new Uri("https://tripfixers.com/wp-content/uploads/2019/11/eiffel-tower-with-snow.jpeg")),
        ],
    };

    // Act
    var result = client.ToMistralChatMessages(message, default);

    // Assert: multi-item content maps to one Mistral message whose content is a chunk list.
    Assert.NotNull(result);
    Assert.Single(result);
    Assert.IsType<List<ContentChunk>>(result[0].Content);
}

[Fact]
public void ValidateToMistralChatMessagesWithFunctionCallContent()
{
Expand Down Expand Up @@ -562,6 +587,41 @@ public void ValidateToMistralChatMessagesWithFunctionResultContent()
Assert.Equal(2, messages.Count);
}

[Fact]
public void ValidateCloneMistralAIPromptExecutionSettings()
{
    // Arrange: populate every setting added by this change so the clone test covers them all.
    var settings = new MistralAIPromptExecutionSettings
    {
        MaxTokens = 1024,
        Temperature = 0.9,
        TopP = 0.9,
        FrequencyPenalty = 0.9,
        PresencePenalty = 0.9,
        Stop = ["stop"],
        SafePrompt = true,
        RandomSeed = 123,
        ResponseFormat = new { format = "json" },
    };

    // Act
    var clonedSettings = settings.Clone();

    // Assert
    // Assert.IsType<T> both asserts non-null/type and returns the typed instance,
    // replacing the previous NotNull + IsType + 'as' cast + null-forgiving '!' sequence.
    var clonedMistralAISettings = Assert.IsType<MistralAIPromptExecutionSettings>(clonedSettings);
    Assert.Equal(settings.MaxTokens, clonedMistralAISettings.MaxTokens);
    Assert.Equal(settings.Temperature, clonedMistralAISettings.Temperature);
    Assert.Equal(settings.TopP, clonedMistralAISettings.TopP);
    Assert.Equal(settings.FrequencyPenalty, clonedMistralAISettings.FrequencyPenalty);
    Assert.Equal(settings.PresencePenalty, clonedMistralAISettings.PresencePenalty);
    Assert.Equal(settings.Stop, clonedMistralAISettings.Stop);
    Assert.Equal(settings.SafePrompt, clonedMistralAISettings.SafePrompt);
    Assert.Equal(settings.RandomSeed, clonedMistralAISettings.RandomSeed);
    Assert.Equal(settings.ResponseFormat, clonedMistralAISettings.ResponseFormat);
}

public sealed class WeatherPlugin
{
[KernelFunction]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,22 @@ internal sealed class ChatCompletionRequest
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? RandomSeed { get; set; }

/// <summary>
/// Response format for the completion; serialized as "response_format" and omitted when null.
/// Kept as <see cref="object"/> so callers can supply an anonymous payload — presumably
/// e.g. { "type": "json_object" }; confirm against the Mistral API reference.
/// </summary>
[JsonPropertyName("response_format")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public object? ResponseFormat { get; set; }

/// <summary>Frequency penalty; serialized as "frequency_penalty" and omitted when null.</summary>
[JsonPropertyName("frequency_penalty")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public double? FrequencyPenalty { get; set; }

/// <summary>Presence penalty; serialized as "presence_penalty" and omitted when null.</summary>
[JsonPropertyName("presence_penalty")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public double? PresencePenalty { get; set; }

/// <summary>Stop sequences; serialized as "stop" and omitted when null.</summary>
[JsonPropertyName("stop")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IList<string>? Stop { get; set; }

/// <summary>
/// Construct an instance of <see cref="ChatCompletionRequest"/>.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ internal sealed class MistralChatMessage
[JsonConstructor]
internal MistralChatMessage(string? role, object? content)
{
    // Accept only the roles the Mistral chat API defines; a null role is allowed.
    if (role is not (null or "system" or "user" or "assistant" or "tool"))
    {
        throw new System.ArgumentException($"Role must be one of: system, user, assistant or tool. {role} is an invalid role.", nameof(role));
    }

    this.Role = role;
    this.Content = content;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,11 @@ private ChatCompletionRequest CreateChatCompletionRequest(string modelId, bool s
TopP = executionSettings.TopP,
MaxTokens = executionSettings.MaxTokens,
SafePrompt = executionSettings.SafePrompt,
RandomSeed = executionSettings.RandomSeed
RandomSeed = executionSettings.RandomSeed,
ResponseFormat = executionSettings.ResponseFormat,
FrequencyPenalty = executionSettings.FrequencyPenalty,
PresencePenalty = executionSettings.PresencePenalty,
Stop = executionSettings.Stop,
};

executionSettings.ToolCallBehavior?.ConfigureRequest(kernel, request);
Expand Down Expand Up @@ -1016,8 +1020,8 @@ private static string ProcessFunctionResult(object functionResult, MistralAITool
return stringResult;
}

// This is an optimization to use ChatMessageContent chatMessage directly
// without unnecessary serialization of the whole message chatMessage class.
// This is an optimization to use ChatMessageContent content directly
// without unnecessary serialization of the whole message content class.
if (functionResult is ChatMessageContent chatMessageContent)
{
return chatMessageContent.ToString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,87 @@ public MistralAIToolCallBehavior? ToolCallBehavior
}
}

/// <summary>
/// Gets or sets the response format to use for the completion.
/// </summary>
/// <remarks>
/// An object specifying the format that the model must output.
/// Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is in JSON.
/// When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message.
/// The setter calls ThrowIfFrozen, so the value cannot be changed after the settings instance is frozen.
/// </remarks>
[JsonPropertyName("response_format")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public object? ResponseFormat
{
    get => this._responseFormat;

    set
    {
        this.ThrowIfFrozen();
        this._responseFormat = value;
    }
}

/// <summary>
/// Gets or sets the stop sequences to use for the completion.
/// </summary>
/// <remarks>
/// Stop generation if this token is detected. Or if one of these tokens is detected when providing an array.
/// The setter calls ThrowIfFrozen, so the value cannot be changed after the settings instance is frozen.
/// </remarks>
[JsonPropertyName("stop")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IList<string>? Stop
{
    get => this._stop;

    set
    {
        this.ThrowIfFrozen();
        this._stop = value;
    }
}

Check failure on line 198 in dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, ubuntu-latest, Release, true, integration)

Remove unnecessary blank line (https://josefpihrt.github.io/docs/roslynator/analyzers/RCS1036)

Check failure on line 198 in dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, ubuntu-latest, Release, true, integration)

Remove unnecessary blank line (https://josefpihrt.github.io/docs/roslynator/analyzers/RCS1036)

Check failure on line 198 in dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, ubuntu-latest, Release, true, integration)

Remove unnecessary blank line (https://josefpihrt.github.io/docs/roslynator/analyzers/RCS1036)

Check failure on line 198 in dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, ubuntu-latest, Release, true, integration)

Remove unnecessary blank line (https://josefpihrt.github.io/docs/roslynator/analyzers/RCS1036)

Check failure on line 198 in dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, windows-latest, Debug)

Remove unnecessary blank line (https://josefpihrt.github.io/docs/roslynator/analyzers/RCS1036)

Check failure on line 198 in dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, windows-latest, Debug)

Remove unnecessary blank line (https://josefpihrt.github.io/docs/roslynator/analyzers/RCS1036)

Check failure on line 198 in dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, windows-latest, Release)

Remove unnecessary blank line (https://josefpihrt.github.io/docs/roslynator/analyzers/RCS1036)

Check failure on line 198 in dotnet/src/Connectors/Connectors.MistralAI/MistralAIPromptExecutionSettings.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, windows-latest, Release)

Remove unnecessary blank line (https://josefpihrt.github.io/docs/roslynator/analyzers/RCS1036)
/// <summary>
/// Number between -2.0 and 2.0. Positive values penalize new tokens
/// based on whether they appear in the text so far, increasing the
/// model's likelihood to talk about new topics.
/// </summary>
/// <remarks>
/// presence_penalty determines how much the model penalizes the repetition of words or phrases.
/// A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative.
/// The setter calls ThrowIfFrozen, so the value cannot be changed after the settings instance is frozen.
/// </remarks>
[JsonPropertyName("presence_penalty")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public double? PresencePenalty
{
    get => this._presencePenalty;

    set
    {
        this.ThrowIfFrozen();
        this._presencePenalty = value;
    }
}

/// <summary>
/// Number between -2.0 and 2.0. Positive values penalize new tokens
/// based on their existing frequency in the text so far, decreasing
/// the model's likelihood to repeat the same line verbatim.
/// </summary>
/// <remarks>
/// frequency_penalty penalizes the repetition of words based on how often they have already appeared in the output.
/// The setter calls ThrowIfFrozen, so the value cannot be changed after the settings instance is frozen.
/// </remarks>
[JsonPropertyName("frequency_penalty")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public double? FrequencyPenalty
{
    get => this._frequencyPenalty;

    set
    {
        this.ThrowIfFrozen();
        this._frequencyPenalty = value;
    }
}

/// <inheritdoc/>
public override void Freeze()
{
Expand All @@ -180,6 +261,10 @@ public override PromptExecutionSettings Clone()
RandomSeed = this.RandomSeed,
ApiVersion = this.ApiVersion,
ToolCallBehavior = this.ToolCallBehavior,
ResponseFormat = this.ResponseFormat,
FrequencyPenalty = this.FrequencyPenalty,
PresencePenalty = this.PresencePenalty,
Stop = this.Stop,
};
}

Expand Down Expand Up @@ -215,6 +300,10 @@ public static MistralAIPromptExecutionSettings FromExecutionSettings(PromptExecu
private int? _randomSeed;
private string _apiVersion = "v1";
private MistralAIToolCallBehavior? _toolCallBehavior;
private object? _responseFormat;
private double? _presencePenalty;
private double? _frequencyPenalty;
private IList<string>? _stop;

#endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,40 @@ public async Task ValidateGetChatMessageContentsWithImageAsync()
Assert.Contains("Snow", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase);
}

[Fact]
public async Task ValidateGetChatMessageContentsWithImageAndJsonFormatAsync()
{
    // Arrange
    var model = this._configuration["MistralAI:ImageModelId"];
    var apiKey = this._configuration["MistralAI:ApiKey"];
    var service = new MistralAIChatCompletionService(model!, apiKey!, httpClient: this._httpClient);

    // Act
    // JSON mode requires explicitly instructing the model to produce JSON in a system or user message.
    var systemMessage = "Return the answer in a JSON object with the next structure: " +
        "{\"elements\": [{\"element\": \"some name of element1\", " +
        "\"description\": \"some description of element 1\"}, " +
        "{\"element\": \"some name of element2\", \"description\": " +
        "\"some description of element 2\"}]}";
    var chatHistory = new ChatHistory(systemMessage)
    {
        new ChatMessageContent(AuthorRole.User, "Describe the image"),
        new ChatMessageContent(AuthorRole.User, [new ImageContent(new Uri("https://tripfixers.com/wp-content/uploads/2019/11/eiffel-tower-with-snow.jpeg"))])
    };
    var executionSettings = new MistralAIPromptExecutionSettings
    {
        MaxTokens = 500,
        ResponseFormat = new { type = "json_object" },
    };
    var response = await service.GetChatMessageContentsAsync(chatHistory, executionSettings);

    // Assert
    Assert.NotNull(response);
    Assert.Single(response);
    Assert.Contains("Paris", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase);
    Assert.Contains("Eiffel Tower", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase);
    Assert.Contains("Snow", response[0].Content, System.StringComparison.InvariantCultureIgnoreCase);

    // The point of this test is the json_object response format, so verify the content actually
    // parses as JSON; JsonDocument.Parse throws JsonException on invalid JSON, failing the test.
    using var jsonDocument = System.Text.Json.JsonDocument.Parse(response[0].Content!);
}

[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateInvokeChatPromptAsync()
{
Expand Down

0 comments on commit aba9fcd

Please sign in to comment.