Skip to content

Commit

Permalink
.Net: Update to latest M.E.AI (#9795)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored Nov 25, 2024
1 parent 33c1de6 commit ada7ba6
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 172 deletions.
8 changes: 4 additions & 4 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@
<PackageVersion Include="System.Net.Http" Version="4.3.4" />
<PackageVersion Include="System.Numerics.Tensors" Version="8.0.0" />
<PackageVersion Include="System.Text.Json" Version="8.0.5" />
<PackageVersion Include="OllamaSharp" Version="4.0.6" />
<PackageVersion Include="OllamaSharp" Version="4.0.8" />
<!-- Tokenizers -->
<PackageVersion Include="Microsoft.ML.Tokenizers" Version="1.0.0" />
<PackageVersion Include="Microsoft.DeepDev.TokenizerLib" Version="1.3.3" />
<PackageVersion Include="SharpToken" Version="2.0.3" />
<!-- Microsoft.Extensions.* -->
<PackageVersion Include="Microsoft.Extensions.AI" Version="9.0.0-preview.9.24556.5" />
<PackageVersion Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24556.5" />
<PackageVersion Include="Microsoft.Extensions.AI.AzureAIInference" Version="9.0.0-preview.9.24556.5" />
<PackageVersion Include="Microsoft.Extensions.AI" Version="9.0.1-preview.1.24570.5" />
<PackageVersion Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.1-preview.1.24570.5" />
<PackageVersion Include="Microsoft.Extensions.AI.AzureAIInference" Version="9.0.1-preview.1.24570.5" />
<PackageVersion Include="Microsoft.Extensions.Configuration" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Configuration.Binder" Version="8.0.2" />
<PackageVersion Include="Microsoft.Extensions.Configuration.EnvironmentVariables" Version="8.0.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Azure;
using Azure.AI.Inference;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureAIInference;
Expand Down Expand Up @@ -51,7 +52,6 @@ public void ConstructorsWorksAsExpected()
{
// Arrange
using var httpClient = new HttpClient() { BaseAddress = this._endpoint };
var loggerFactoryMock = new Mock<ILoggerFactory>();
ChatCompletionsClient client = new(this._endpoint, new AzureKeyCredential("api-key"));

// Act & Assert
Expand All @@ -60,12 +60,12 @@ public void ConstructorsWorksAsExpected()
new AzureAIInferenceChatCompletionService(modelId: "model", httpClient: httpClient, apiKey: null); // Only the HttpClient with a BaseClass defined
new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null); // ModelId and endpoint
new AzureAIInferenceChatCompletionService(modelId: "model", apiKey: "api-key", endpoint: this._endpoint); // ModelId, apiKey, and endpoint
new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null, loggerFactory: loggerFactoryMock.Object); // Endpoint and loggerFactory
new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null, loggerFactory: NullLoggerFactory.Instance); // Endpoint and loggerFactory

// Breaking Glass constructor
new AzureAIInferenceChatCompletionService(modelId: null, chatClient: client); // Client without model
new AzureAIInferenceChatCompletionService(modelId: "model", chatClient: client); // Client
new AzureAIInferenceChatCompletionService(modelId: "model", chatClient: client, loggerFactory: loggerFactoryMock.Object); // Client
new AzureAIInferenceChatCompletionService(modelId: "model", chatClient: client, loggerFactory: NullLoggerFactory.Instance); // Client
}

[Theory]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Http;

namespace Microsoft.SemanticKernel;

Expand Down Expand Up @@ -38,34 +37,30 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(
{
Verify.NotNull(services);

services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
return services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
{
var chatClientBuilder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);
var options = new AzureAIInferenceClientOptions();
var logger = serviceProvider.GetService<ILoggerFactory>()?.CreateLogger<ChatCompletionsClient>();
if (logger is not null)
httpClient ??= serviceProvider.GetService<HttpClient>();
if (httpClient is not null)
{
chatClientBuilder.UseLogging(logger);
options.Transport = new HttpClientTransport(httpClient);
}
var options = new AzureAIInferenceClientOptions();
if (httpClient is not null)
var loggerFactory = serviceProvider.GetService<ILoggerFactory>();
var builder = new Azure.AI.Inference.ChatCompletionsClient(endpoint, new Azure.AzureKeyCredential(apiKey ?? SingleSpace), options)
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes);
if (loggerFactory is not null)
{
options.Transport = new HttpClientTransport(HttpClientProvider.GetHttpClient(httpClient, serviceProvider));
builder.UseLogging(loggerFactory);
}
return
chatClientBuilder.Use(
new Microsoft.Extensions.AI.AzureAIInferenceChatClient(
modelId: modelId,
chatCompletionsClient: new Azure.AI.Inference.ChatCompletionsClient(endpoint, new Azure.AzureKeyCredential(apiKey ?? SingleSpace), options)
)
).AsChatCompletionService();
return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider);
});

return services;
}

/// <summary>
Expand All @@ -88,34 +83,30 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(
{
Verify.NotNull(services);

services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
return services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
{
var chatClientBuilder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);
var options = new AzureAIInferenceClientOptions();
var logger = serviceProvider.GetService<ILoggerFactory>()?.CreateLogger<ChatCompletionsClient>();
if (logger is not null)
httpClient ??= serviceProvider.GetService<HttpClient>();
if (httpClient is not null)
{
chatClientBuilder.UseLogging(logger);
options.Transport = new HttpClientTransport(httpClient);
}
var options = new AzureAIInferenceClientOptions();
if (httpClient is not null)
var loggerFactory = serviceProvider.GetService<ILoggerFactory>();
var builder = new Azure.AI.Inference.ChatCompletionsClient(endpoint, credential, options)
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes);
if (loggerFactory is not null)
{
options.Transport = new HttpClientTransport(HttpClientProvider.GetHttpClient(httpClient, serviceProvider));
builder.UseLogging(loggerFactory);
}
return
chatClientBuilder.Use(
new Microsoft.Extensions.AI.AzureAIInferenceChatClient(
modelId: modelId,
chatCompletionsClient: new Azure.AI.Inference.ChatCompletionsClient(endpoint, credential, options)
)
).AsChatCompletionService();
return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider);
});

return services;
}

/// <summary>
Expand All @@ -133,26 +124,24 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(this IService
{
Verify.NotNull(services);

services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
return services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
{
chatClient ??= serviceProvider.GetRequiredService<ChatCompletionsClient>();
var chatClientBuilder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);
var loggerFactory = serviceProvider.GetService<ILoggerFactory>();
var logger = serviceProvider.GetService<ILoggerFactory>()?.CreateLogger<ChatCompletionsClient>();
if (logger is not null)
var builder = chatClient
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes);
if (loggerFactory is not null)
{
chatClientBuilder.UseLogging(logger);
builder.UseLogging(loggerFactory);
}
return chatClientBuilder
.Use(new Microsoft.Extensions.AI.AzureAIInferenceChatClient(chatClient, modelId))
.AsChatCompletionService();
return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider);
});

return services;
}

/// <summary>
Expand All @@ -168,26 +157,23 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(this IService
{
Verify.NotNull(services);

services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
return services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
{
chatClient ??= serviceProvider.GetRequiredService<AzureAIInferenceChatClient>();
var chatClientBuilder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);
var loggerFactory = serviceProvider.GetService<ILoggerFactory>();
var logger = serviceProvider.GetService<ILoggerFactory>()?.CreateLogger<ChatCompletionsClient>();
if (logger is not null)
var builder = chatClient
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes);
if (loggerFactory is not null)
{
chatClientBuilder.UseLogging(logger);
builder.UseLogging(loggerFactory);
}
return chatClientBuilder
.Use(chatClient)
.AsChatCompletionService();
return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider);
});

return services;
}

#region Private
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,19 @@ public AzureAIInferenceChatCompletionService(
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
{
var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService));
this._core = new(
modelId,
apiKey,
endpoint,
httpClient,
logger);

var builder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

if (logger is not null)
this._core = new ChatClientCore(modelId, apiKey, endpoint, httpClient);

var builder = this._core.Client
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

if (loggerFactory is not null)
{
builder = builder.UseLogging(logger);
builder.UseLogging(loggerFactory);
}

this._chatService = builder
.Use(this._core.Client.AsChatClient(modelId))
.AsChatCompletionService();
this._chatService = builder.Build().AsChatCompletionService();
}

/// <summary>
Expand All @@ -75,26 +68,19 @@ public AzureAIInferenceChatCompletionService(
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
{
var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService));
this._core = new(
modelId,
credential,
endpoint,
httpClient,
logger);

var builder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

if (logger is not null)
this._core = new ChatClientCore(modelId, credential, endpoint, httpClient);

var builder = this._core.Client
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

if (loggerFactory is not null)
{
builder = builder.UseLogging(logger);
builder.UseLogging(loggerFactory);
}

this._chatService = builder
.Use(this._core.Client.AsChatClient(modelId))
.AsChatCompletionService();
this._chatService = builder.Build().AsChatCompletionService();
}

/// <summary>
Expand All @@ -108,24 +94,21 @@ public AzureAIInferenceChatCompletionService(
ChatCompletionsClient chatClient,
ILoggerFactory? loggerFactory = null)
{
var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService));
this._core = new(
modelId,
chatClient,
logger);
Verify.NotNull(chatClient);

this._core = new ChatClientCore(modelId, chatClient);

var builder = new ChatClientBuilder()
.UseFunctionInvocation(config =>
config.MaximumIterationsPerRequest = MaxInflightAutoInvokes);
var builder = chatClient
.AsChatClient(modelId)
.AsBuilder()
.UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes);

if (logger is not null)
if (loggerFactory is not null)
{
builder = builder.UseLogging(logger);
builder.UseLogging(loggerFactory);
}

this._chatService = builder
.Use(this._core.Client.AsChatClient(modelId))
.AsChatCompletionService();
this._chatService = builder.Build().AsChatCompletionService();
}

/// <inheritdoc/>
Expand Down
Loading

0 comments on commit ada7ba6

Please sign in to comment.