diff --git a/extensions/AWSSDK.Extensions.sln b/extensions/AWSSDK.Extensions.sln index 77675d2e7b1e..8a44ac3f73fb 100644 --- a/extensions/AWSSDK.Extensions.sln +++ b/extensions/AWSSDK.Extensions.sln @@ -36,11 +36,17 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CloudFront.Signers.Tests.Ne EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "EC2.DecryptPassword.NetStandard", "test\EC2.DecryptPasswordTests\EC2.DecryptPassword.NetStandard.csproj", "{EA6EEC77-E69B-4D42-B9F2-BADCEEE5A32B}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.CloudFront.NetStandard", "..\sdk\src\Services\CloudFront\AWSSDK.CloudFront.NetStandard.csproj", "{280223DF-ECB0-4B38-A3A6-B80B46D48475}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "BedrockMEAITests.NetFramework", "test\BedrockMEAITests\BedrockMEAITests.NetFramework.csproj", "{D98D6380-80A3-4818-84B4-3BD332383CA2}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.BedrockRuntime.NetStandard", "..\sdk\src\Services\BedrockRuntime\AWSSDK.BedrockRuntime.NetStandard.csproj", "{280223DF-ECB0-4B38-A3A6-B80B46D48475}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.CloudFront.NetStandard", "..\sdk\src\Services\CloudFront\AWSSDK.CloudFront.NetStandard.csproj", "{71C8FC92-F868-4E07-B005-62180C1D6B8B}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.EC2.NetStandard", "..\sdk\src\Services\EC2\AWSSDK.EC2.NetStandard.csproj", "{FC70CF98-BA7E-4F9F-A5DB-966973284091}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.CloudFront.NetFramework", "..\sdk\src\Services\CloudFront\AWSSDK.CloudFront.NetFramework.csproj", "{4FFF9872-1D77-4664-83C6-B46AC6EB1E20}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.BedrockRuntime.NetFramework", "..\sdk\src\Services\BedrockRuntime\AWSSDK.BedrockRuntime.NetFramework.csproj", "{4FFF9872-1D77-4664-83C6-B46AC6EB1E20}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.CloudFront.NetFramework", "..\sdk\src\Services\CloudFront\AWSSDK.CloudFront.NetFramework.csproj", "{B416F870-421E-410A-8848-13A7F523E669}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.EC2.NetFramework", "..\sdk\src\Services\EC2\AWSSDK.EC2.NetFramework.csproj", "{0377B228-91F3-4A0B-BE66-221E7ECA6DF7}" EndProject @@ -48,6 +54,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.Extensions.CloudFron EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.Extensions.EC2.DecryptPassword.NetFramework", "src\AWSSDK.Extensions.EC2.DecryptPassword\AWSSDK.Extensions.EC2.DecryptPassword.NetFramework.csproj", "{3EC669E6-A541-445E-B68E-0A853715E39C}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.Extensions.Bedrock.MEAI.NetFramework", "src\AWSSDK.Extensions.Bedrock.MEAI\AWSSDK.Extensions.Bedrock.MEAI.NetFramework.csproj", "{4A94F623-0C71-47BD-B927-CB6FA28D33A1}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.Extensions.Bedrock.MEAI.NetStandard", "src\AWSSDK.Extensions.Bedrock.MEAI\AWSSDK.Extensions.Bedrock.MEAI.NetStandard.csproj", "{B174860A-0D1B-4B7D-9E46-7DBFC9AA5AAB}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -102,10 +112,22 @@ Global {EA6EEC77-E69B-4D42-B9F2-BADCEEE5A32B}.Debug|Any CPU.Build.0 = Debug|Any CPU {EA6EEC77-E69B-4D42-B9F2-BADCEEE5A32B}.Release|Any CPU.ActiveCfg = Release|Any CPU {EA6EEC77-E69B-4D42-B9F2-BADCEEE5A32B}.Release|Any CPU.Build.0 = Release|Any CPU + {B5244288-5997-4E72-8AD8-936D346C02CE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B5244288-5997-4E72-8AD8-936D346C02CE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B5244288-5997-4E72-8AD8-936D346C02CE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B5244288-5997-4E72-8AD8-936D346C02CE}.Release|Any CPU.Build.0 = Release|Any CPU + {D98D6380-80A3-4818-84B4-3BD332383CA2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D98D6380-80A3-4818-84B4-3BD332383CA2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D98D6380-80A3-4818-84B4-3BD332383CA2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D98D6380-80A3-4818-84B4-3BD332383CA2}.Release|Any CPU.Build.0 = Release|Any CPU {280223DF-ECB0-4B38-A3A6-B80B46D48475}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {280223DF-ECB0-4B38-A3A6-B80B46D48475}.Debug|Any CPU.Build.0 = Debug|Any CPU {280223DF-ECB0-4B38-A3A6-B80B46D48475}.Release|Any CPU.ActiveCfg = Release|Any CPU {280223DF-ECB0-4B38-A3A6-B80B46D48475}.Release|Any CPU.Build.0 = Release|Any CPU + {71C8FC92-F868-4E07-B005-62180C1D6B8B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {71C8FC92-F868-4E07-B005-62180C1D6B8B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {71C8FC92-F868-4E07-B005-62180C1D6B8B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {71C8FC92-F868-4E07-B005-62180C1D6B8B}.Release|Any CPU.Build.0 = Release|Any CPU {FC70CF98-BA7E-4F9F-A5DB-966973284091}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {FC70CF98-BA7E-4F9F-A5DB-966973284091}.Debug|Any CPU.Build.0 = Debug|Any CPU {FC70CF98-BA7E-4F9F-A5DB-966973284091}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -114,6 +136,10 @@ Global {4FFF9872-1D77-4664-83C6-B46AC6EB1E20}.Debug|Any CPU.Build.0 = Debug|Any CPU {4FFF9872-1D77-4664-83C6-B46AC6EB1E20}.Release|Any CPU.ActiveCfg = Release|Any CPU {4FFF9872-1D77-4664-83C6-B46AC6EB1E20}.Release|Any CPU.Build.0 = Release|Any CPU + {B416F870-421E-410A-8848-13A7F523E669}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B416F870-421E-410A-8848-13A7F523E669}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B416F870-421E-410A-8848-13A7F523E669}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B416F870-421E-410A-8848-13A7F523E669}.Release|Any CPU.Build.0 = Release|Any CPU {0377B228-91F3-4A0B-BE66-221E7ECA6DF7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {0377B228-91F3-4A0B-BE66-221E7ECA6DF7}.Debug|Any CPU.Build.0 = Debug|Any CPU {0377B228-91F3-4A0B-BE66-221E7ECA6DF7}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -126,6 +152,14 @@ Global {3EC669E6-A541-445E-B68E-0A853715E39C}.Debug|Any CPU.Build.0 = Debug|Any CPU {3EC669E6-A541-445E-B68E-0A853715E39C}.Release|Any CPU.ActiveCfg = Release|Any CPU {3EC669E6-A541-445E-B68E-0A853715E39C}.Release|Any CPU.Build.0 = Release|Any CPU + {4A94F623-0C71-47BD-B927-CB6FA28D33A1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4A94F623-0C71-47BD-B927-CB6FA28D33A1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4A94F623-0C71-47BD-B927-CB6FA28D33A1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4A94F623-0C71-47BD-B927-CB6FA28D33A1}.Release|Any CPU.Build.0 = Release|Any CPU + {B174860A-0D1B-4B7D-9E46-7DBFC9AA5AAB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B174860A-0D1B-4B7D-9E46-7DBFC9AA5AAB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B174860A-0D1B-4B7D-9E46-7DBFC9AA5AAB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B174860A-0D1B-4B7D-9E46-7DBFC9AA5AAB}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -143,12 +177,18 @@ Global {C8A027AB-282C-400E-893D-971A5D55DB17} = {3D822DC2-ED2E-4434-BC4F-CE7FCD846B02} {A552BA51-D17C-4594-BF0A-DF7F53EA688D} = {A960D001-40B3-4B1A-A890-D1049FB7586E} {EA6EEC77-E69B-4D42-B9F2-BADCEEE5A32B} = {A960D001-40B3-4B1A-A890-D1049FB7586E} + {B5244288-5997-4E72-8AD8-936D346C02CE} = {A960D001-40B3-4B1A-A890-D1049FB7586E} + {D98D6380-80A3-4818-84B4-3BD332383CA2} = {A960D001-40B3-4B1A-A890-D1049FB7586E} {280223DF-ECB0-4B38-A3A6-B80B46D48475} = {0BA39F07-84D6-420B-82D3-6DC3AF016C65} + {71C8FC92-F868-4E07-B005-62180C1D6B8B} = {0BA39F07-84D6-420B-82D3-6DC3AF016C65} {FC70CF98-BA7E-4F9F-A5DB-966973284091} = {0BA39F07-84D6-420B-82D3-6DC3AF016C65} {4FFF9872-1D77-4664-83C6-B46AC6EB1E20} = {0BA39F07-84D6-420B-82D3-6DC3AF016C65} + {B416F870-421E-410A-8848-13A7F523E669} = {0BA39F07-84D6-420B-82D3-6DC3AF016C65} {0377B228-91F3-4A0B-BE66-221E7ECA6DF7} = {0BA39F07-84D6-420B-82D3-6DC3AF016C65} {E195094D-5899-4FDF-969D-93C4432BA921} = {3D822DC2-ED2E-4434-BC4F-CE7FCD846B02} {3EC669E6-A541-445E-B68E-0A853715E39C} = {3D822DC2-ED2E-4434-BC4F-CE7FCD846B02} + {4A94F623-0C71-47BD-B927-CB6FA28D33A1} = {3D822DC2-ED2E-4434-BC4F-CE7FCD846B02} + {B174860A-0D1B-4B7D-9E46-7DBFC9AA5AAB} = {3D822DC2-ED2E-4434-BC4F-CE7FCD846B02} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {949367A4-5683-4FD3-93F4-A2CEA6EECB21} diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.NetFramework.csproj b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.NetFramework.csproj new file mode 100644 index 000000000000..84228853072e --- /dev/null +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.NetFramework.csproj @@ -0,0 +1,46 @@ + + + net472 + AWSSDK.Extensions.Bedrock.MEAI + AWSSDK.Extensions.Bedrock.MEAI + + false + false + false + false + false + false + false + false + true + + Latest + enable + + + + + + + + + + ..\..\..\sdk\awssdk.dll.snk + + + + + $(AWSKeyFile) + + + + + + + + + + + + + diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.NetStandard.csproj b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.NetStandard.csproj new file mode 100644 index 000000000000..56fbdcbf3921 --- /dev/null +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.NetStandard.csproj @@ -0,0 +1,50 @@ + + + netstandard2.0;net8.0 + AWSSDK.Extensions.Bedrock.MEAI + AWSSDK.Extensions.Bedrock.MEAI + + false + false + false + false + false + false + false + false + true + + Latest + enable + + + + true + + + + + + + + + + ..\..\..\sdk\awssdk.dll.snk + + + + + $(AWSKeyFile) + + + + + + + + + + + + + diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.nuspec b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.nuspec new file mode 100644 index 000000000000..dc8519315959 --- /dev/null +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.nuspec @@ -0,0 +1,44 @@ + + + + AWSSDK.Extensions.Bedrock.MEAI + AWSSDK - Bedrock integration with Microsoft.Extensions.AI. + 4.0.0.0-preview.4 + Amazon Web Services + Implementations of Microsoft.Extensions.AI's abstractions for Bedrock. + en-US + Apache-2.0 + https://github.com/aws/aws-sdk-net/ + AWS Amazon aws-sdk-v4 + images\AWSLogo.png + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AmazonBedrockRuntimeExtensions.cs b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AmazonBedrockRuntimeExtensions.cs new file mode 100644 index 000000000000..6443fb96cc78 --- /dev/null +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AmazonBedrockRuntimeExtensions.cs @@ -0,0 +1,54 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using Microsoft.Extensions.AI; +using System; + +namespace Amazon.BedrockRuntime; + +/// Provides extensions for working with instances. +public static class AmazonBedrockRuntimeExtensions +{ + /// Gets an for the specified instance. + /// The runtime instance to be represented as an . + /// + /// The default model ID to use when no model is specified in a request. If not specified, + /// a model must be provided in the passed to + /// or . + /// + /// A instance representing the instance. + /// is . + public static IChatClient AsChatClient(this IAmazonBedrockRuntime runtime, string? modelId = null) => + runtime is not null ? new BedrockChatClient(runtime, modelId) : + throw new ArgumentNullException(nameof(runtime)); + + /// Gets an for the specified instance. + /// The runtime instance to be represented as an . + /// + /// The default model ID to use when no model is specified in a request. If not specified, + /// a model must be provided in the passed to + /// or . + /// + /// + /// The default number of dimensions to request be generated. This will be overridden by a + /// if that is specified to a request. If neither is specified, the default for the model will be used. + /// + /// A instance representing the instance. + /// is . + public static IEmbeddingGenerator> AsEmbeddingGenerator( + this IAmazonBedrockRuntime runtime, string? modelId = null, int? dimensions = null) => + runtime is not null ? new BedrockEmbeddingGenerator(runtime, modelId, dimensions) : + throw new ArgumentNullException(nameof(runtime)); +} diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs new file mode 100644 index 000000000000..3553692c6d4d --- /dev/null +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs @@ -0,0 +1,611 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using Amazon.BedrockRuntime.Model; +using Amazon.Runtime.Documents; +using Microsoft.Extensions.AI; +using System; +using System.Collections.Generic; +using System.Diagnostics; +#if NET8_0_OR_GREATER +using System.Diagnostics.CodeAnalysis; +#endif +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.BedrockRuntime; + +internal sealed partial class BedrockChatClient : IChatClient +{ + /// The wrapped instance. + private readonly IAmazonBedrockRuntime _runtime; + /// Default model ID to use when no model is specified in the request. + private readonly string? _modelId; + + /// + /// Initializes a new instance of the class. + /// + /// The instance to wrap. + /// Model ID to use as the default when no model ID is specified in a request. + public BedrockChatClient(IAmazonBedrockRuntime runtime, string? modelId) + { + Debug.Assert(runtime is not null); + + _runtime = runtime!; + _modelId = modelId; + + Metadata = new(runtime!.Config.ServiceId, modelId: modelId); + } + + public void Dispose() + { + // Do not dispose of _runtime, as this instance doesn't own it. + } + + /// + public ChatClientMetadata Metadata { get; } + + /// + public async Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + if (chatMessages is null) + { + throw new ArgumentNullException(nameof(chatMessages)); + } + + ConverseRequest request = new() + { + ModelId = options?.ModelId ?? _modelId, + Messages = CreateMessages(chatMessages), + System = CreateSystem(chatMessages), + ToolConfig = CreateToolConfig(options), + InferenceConfig = CreateInferenceConfiguration(options), + AdditionalModelRequestFields = CreateAdditionalModelRequestFields(options), + }; + + var response = await _runtime.ConverseAsync(request, cancellationToken).ConfigureAwait(false); + + ChatMessage result = new() + { + Role = ChatRole.Assistant, + }; + + if (response.Output?.Message?.Content is { } contents) + { + foreach (var content in contents) + { + if (content.Text is string text) + { + result.Contents.Add(new TextContent(text)); + } + + if (content.Image is { Source.Bytes: { } bytes, Format.Value: { } formatValue }) + { + result.Contents.Add(new ImageContent(bytes.ToArray(), $"image/{formatValue}")); + } + + if (content.ToolUse is { } toolUse) + { + result.Contents.Add(new FunctionCallContent(toolUse.ToolUseId, toolUse.Name, DocumentToDictionary(toolUse.Input))); + } + } + } + + if (DocumentToDictionary(response.AdditionalModelResponseFields) is { } responseFieldsDictionary) + { + result.AdditionalProperties = new(responseFieldsDictionary); + } + + return new ChatCompletion(result) + { + FinishReason = response.StopReason is not null ? GetChatFinishReason(response.StopReason) : null, + Usage = response.Usage is TokenUsage usage ? new() + { + InputTokenCount = usage.InputTokens, + OutputTokenCount = usage.OutputTokens, + TotalTokenCount = usage.TotalTokens, + } : null, + }; + } + + /// + public async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (chatMessages is null) + { + throw new ArgumentNullException(nameof(chatMessages)); + } + + ConverseStreamRequest request = new() + { + ModelId = options?.ModelId ?? _modelId, + Messages = CreateMessages(chatMessages), + System = CreateSystem(chatMessages), + ToolConfig = CreateToolConfig(options), + InferenceConfig = CreateInferenceConfiguration(options), + AdditionalModelRequestFields = CreateAdditionalModelRequestFields(options), + }; + + var result = await _runtime.ConverseStreamAsync(request, cancellationToken).ConfigureAwait(false); + + string? toolName = null; + string? toolId = null; + StringBuilder? toolInput = null; + ChatFinishReason? finishReason = null; + await foreach (var update in result.Stream.ConfigureAwait(false)) + { + switch (update) + { + case MessageStartEvent messageStart: + yield return new() + { + Role = ChatRole.Assistant, + FinishReason = finishReason, + }; + break; + + case ContentBlockStartEvent contentBlockStart when contentBlockStart?.Start?.ToolUse is ToolUseBlockStart tubs: + toolName ??= tubs.Name; + toolId ??= tubs.ToolUseId; + break; + + case ContentBlockDeltaEvent contentBlockDelta when contentBlockDelta.Delta is not null: + if (contentBlockDelta.Delta.ToolUse is ToolUseBlockDelta tubd) + { + (toolInput ??= new()).Append(tubd.Input); + } + + if (contentBlockDelta.Delta.Text is string text) + { + yield return new() + { + Role = ChatRole.Assistant, + FinishReason = finishReason, + Text = text, + }; + } + break; + + case ContentBlockStopEvent contentBlockStop: + if (toolName is not null && toolId is not null) + { + Dictionary? inputs = ParseToolInputs(toolInput?.ToString(), out Exception? parseError); + yield return new() + { + Role = ChatRole.Assistant, + FinishReason = finishReason, + Contents = [new FunctionCallContent(toolId, toolName, inputs) { Exception = parseError }], + }; + } + + toolName = null; + toolId = null; + toolInput = null; + break; + + case MessageStopEvent messageStop: + if (messageStop.StopReason is not null) + { + finishReason ??= GetChatFinishReason(messageStop.StopReason); + } + + AdditionalPropertiesDictionary? additionalProps = null; + if (DocumentToDictionary(messageStop.AdditionalModelResponseFields) is { } responseFieldsDictionary) + { + additionalProps = new(responseFieldsDictionary); + } + + yield return new() + { + Role = ChatRole.Assistant, + FinishReason = finishReason, + AdditionalProperties = additionalProps, + }; + break; + + case ConverseStreamMetadataEvent metadata when metadata.Usage is TokenUsage usage: + yield return new() + { + Role = ChatRole.Assistant, + FinishReason = finishReason, + Contents = + [ + new UsageContent(new() + { + InputTokenCount = usage.InputTokens, + OutputTokenCount = usage.OutputTokens, + TotalTokenCount = usage.TotalTokens, + }) + ], + }; + break; + } + } + } + + /// + public object? GetService(Type serviceType, object? key) + { + if (serviceType is null) + { + throw new ArgumentNullException(nameof(serviceType)); + } + + return + key is not null ? null : + serviceType.IsInstanceOfType(_runtime) ? _runtime : + serviceType.IsInstanceOfType(this) ? this : + null; + } + + /// Converts a into a . + private static ChatFinishReason GetChatFinishReason(StopReason stopReason) => + stopReason.Value switch + { + "content_filtered" => ChatFinishReason.ContentFilter, + "guardrail_intervened" => ChatFinishReason.ContentFilter, + "end_turn" => ChatFinishReason.Stop, + "max_tokens" => ChatFinishReason.Length, + "stop_sequence" => ChatFinishReason.Stop, + "tool_use" => ChatFinishReason.ToolCalls, + _ => new(stopReason.Value), + }; + + /// Creates a list of from the system messages in the provided . + private static List CreateSystem(IList chatMessages) => + chatMessages + .Where(m => m.Role == ChatRole.System && m.Contents.Any(c => c is TextContent)) + .Select(m => new SystemContentBlock() { Text = string.Concat(m.Contents.OfType()) }) + .ToList(); + + /// Parses JSON tool input into a . + private static Dictionary? ParseToolInputs(string? jsonInput, out Exception? parseError) + { + parseError = null; + if (jsonInput is not null) + { + try + { + return (Dictionary?)JsonSerializer.Deserialize(jsonInput, BedrockJsonContext.DefaultOptions.GetTypeInfo(typeof(Dictionary))); + } + catch (Exception e) + { + parseError = new InvalidOperationException($"Unable to parse input: {jsonInput}", e); + } + } + + return null; + } + + /// Creates a list of from the provided . + private static List CreateMessages(IList chatMessages) + { + List messages = []; + + foreach (ChatMessage chatMessage in chatMessages) + { + if (chatMessage.Role == ChatRole.System) + { + continue; + } + + messages.Add(new() + { + Role = chatMessage.Role == ChatRole.Assistant ? ConversationRole.Assistant : ConversationRole.User, + Content = CreateContents(chatMessage), + }); + } + + return messages; + } + + /// Creates a list of s from a . + private static List CreateContents(ChatMessage message) + { + List contents = []; + + foreach (AIContent content in message.Contents) + { + switch (content) + { + case TextContent tc: + contents.Add(new() { Text = tc.Text }); + break; + + case ImageContent ic when ic.ContainsData: + contents.Add(new() + { + Image = new() + { + Source = new() { Bytes = new(ic.Data!.Value.ToArray()) }, + Format = ic.MediaType switch + { + "image/jpeg" => ImageFormat.Jpeg, + "image/png" => ImageFormat.Png, + "image/gif" => ImageFormat.Gif, + "image/webp" => ImageFormat.Webp, + _ => null, + }, + } + }); + break; + + case FunctionCallContent fcc: + contents.Add(new() + { + ToolUse = new() + { + ToolUseId = fcc.CallId, + Name = fcc.Name, + Input = DictionaryToDocument(fcc.Arguments), + } + }); + break; + + case FunctionResultContent frc: + Document result = frc.Result switch + { + int i => i, + long l => l, + float f => f, + double d => d, + string s => s, + bool b => b, + JsonElement json => ToDocument(json), + { } other => ToDocument(JsonSerializer.SerializeToElement(other, BedrockJsonContext.DefaultOptions.GetTypeInfo(other.GetType()))), + _ => default, + }; + + contents.Add(new() + { + ToolResult = new() + { + ToolUseId = frc.CallId, + Content = [new() { Json = new Document(new Dictionary() { ["result"] = result }) }], + }, + }); + break; + } + } + + return contents; + } + + /// Converts a to a . + private static Document DictionaryToDocument(IDictionary? arguments) + { + Document inputs = default; + if (arguments is not null) + { + foreach (KeyValuePair argument in arguments) + { + switch (argument.Value) + { + case bool argumentBool: inputs.Add(argument.Key, argumentBool); break; + case int argumentInt32: inputs.Add(argument.Key, argumentInt32); break; + case long argumentInt64: inputs.Add(argument.Key, argumentInt64); break; + case float argumentSingle: inputs.Add(argument.Key, argumentSingle); break; + case double argumentDouble: inputs.Add(argument.Key, argumentDouble); break; + case string argumentString: inputs.Add(argument.Key, argumentString); break; + case JsonElement json: inputs.Add(argument.Key, ToDocument(json)); break; + } + } + } + + return inputs; + } + + /// Converts a to a . + private static Dictionary? DocumentToDictionary(Document d) + { + if (d.IsDictionary()) + { + return (Dictionary?) + DocumentDictionaryToNode(d.AsDictionary()) + .Deserialize(BedrockJsonContext.DefaultOptions.GetTypeInfo(typeof(Dictionary))); + } + + return null; + } + + /// Converts a to a . + private static JsonObject DocumentDictionaryToNode(Dictionary documentDictionary) => + new(documentDictionary.Select(entry => new KeyValuePair(entry.Key, DocumentToNode(entry.Value)))); + + /// Converts a to a . + private static JsonNode? DocumentToNode(Document value) => + value.IsBool() ? value.AsBool() : + value.IsInt() ? value.AsInt() : + value.IsLong() ? value.AsLong() : + value.IsDouble() ? value.AsDouble() : + value.IsString() ? value.AsString() : + value.IsList() ? new JsonArray(value.AsList().Select(DocumentToNode).ToArray()) : + value.IsDictionary() ? DocumentDictionaryToNode(value.AsDictionary()) : + null; + + /// Converts a to a . + private static Document ToDocument(JsonElement json) + { + switch (json.ValueKind) + { + case JsonValueKind.String: + return json.GetString(); + + case JsonValueKind.Number: + return json.GetDouble(); + + case JsonValueKind.True: + return true; + + case JsonValueKind.False: + return false; + + case JsonValueKind.Array: + var elements = new Document[json.GetArrayLength()]; + for (int i = 0; i < elements.Length; i++) + { + elements[i] = ToDocument(json[i]); + } + return elements; + + case JsonValueKind.Object: + Dictionary props = []; + foreach (var prop in json.EnumerateObject()) + { + props.Add(prop.Name, ToDocument(prop.Value)); + } + return props; + + case JsonValueKind.Null: + default: + return string.Empty; + } + } + + /// Creates an from the specified options. + private static ToolConfiguration? CreateToolConfig(ChatOptions? options) + { + List? tools = options?.Tools?.OfType().Select(f => + { + Document inputs = default; + List required = []; + + foreach (var parameter in f.Metadata.Parameters) + { + inputs.Add(parameter.Name, parameter.Schema is JsonElement schema ? ToDocument(schema) : new Document(true)); + if (parameter.IsRequired) + { + required.Add(parameter.Name); + } + } + + return new Tool() + { + ToolSpec = new ToolSpecification() + { + Name = f.Metadata.Name, + Description = !string.IsNullOrEmpty(f.Metadata.Description) ? f.Metadata.Description : f.Metadata.Name, + InputSchema = new() + { + Json = new(new Dictionary() + { + ["type"] = new Document("object"), + ["properties"] = inputs, + ["required"] = new Document(required), + }) + }, + }, + }; + }).ToList(); + + ToolChoice? choice = null; + if (tools is { Count: > 0 }) + { + switch (options!.ToolMode) + { + case AutoChatToolMode: + choice = new ToolChoice() { Auto = new() }; + break; + + case RequiredChatToolMode r: + choice = !string.IsNullOrWhiteSpace(r.RequiredFunctionName) ? + new ToolChoice() { Tool = new() { Name = r.RequiredFunctionName } } : + new ToolChoice() { Any = new() }; + break; + } + + return new() + { + ToolChoice = choice, + Tools = tools, + }; + } + + return null; + } + + /// Creates an from the specified options. + private static InferenceConfiguration CreateInferenceConfiguration(ChatOptions? options) => + new() + { + MaxTokens = options?.MaxOutputTokens, + StopSequences = options?.StopSequences?.ToList(), + Temperature = options?.Temperature, + TopP = options?.TopP, + }; + + /// Creates a from the specified options to use as the additional model request options. + private static Document CreateAdditionalModelRequestFields(ChatOptions? options) + { + Document d = default; + + if (options is not null) + { + if (options.TopK is int topK) + { + d.Add("k", topK); + } + + if (options.FrequencyPenalty is float frequencyPenalty) + { + d.Add("frequency_penalty", frequencyPenalty); + } + + if (options.PresencePenalty is float presencePenalty) + { + d.Add("presence_penalty", presencePenalty); + } + + if (options.Seed is long seed) + { + d.Add("seed", seed); + } + + if (options.AdditionalProperties is { } props) + { + foreach (KeyValuePair prop in props) + { + switch (prop.Value) + { + case bool propBool: d.Add(prop.Key, propBool); break; + case int propInt32: d.Add(prop.Key, propInt32); break; + case long propInt64: d.Add(prop.Key, propInt64); break; + case float propSingle: d.Add(prop.Key, propSingle); break; + case double propDouble: d.Add(prop.Key, propDouble); break; + case string propString: d.Add(prop.Key, propString); break; + case null: d.Add(prop.Key, default); break; + case JsonElement json: d.Add(prop.Key, ToDocument(json)); break; + default: + try + { + d.Add(prop.Key, ToDocument(JsonSerializer.SerializeToElement(prop.Value, BedrockJsonContext.DefaultOptions.GetTypeInfo(prop.Value.GetType())))); + } + catch { } + break; + } + } + } + } + + return d; + } +} diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockEmbeddingGenerator.cs b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockEmbeddingGenerator.cs new file mode 100644 index 000000000000..162f91dd17bb --- /dev/null +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockEmbeddingGenerator.cs @@ -0,0 +1,125 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using Microsoft.Extensions.AI; +using System; +using System.Collections.Generic; +using System.Diagnostics; +#if NET8_0_OR_GREATER +using System.Diagnostics.CodeAnalysis; +#endif +using System.IO; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.BedrockRuntime; + +internal sealed partial class BedrockEmbeddingGenerator : IEmbeddingGenerator> +{ + /// The wrapped instance. + private readonly IAmazonBedrockRuntime _runtime; + /// Default model ID to use when no model is specified in the request. + private readonly string? _modelId; + /// Default number of dimensions to use when no number of dimensions is specified in the request. + private readonly int? _dimensions; + + /// + /// Initializes a new instance of the class. + /// + /// The instance to wrap. + /// Model ID to use as the default when no model ID is specified in a request. + /// Number of dimensions to use when no number of dimensions is specified in a request. + public BedrockEmbeddingGenerator(IAmazonBedrockRuntime runtime, string? modelId, int? dimensions) + { + Debug.Assert(runtime is not null); + + _runtime = runtime!; + _modelId = modelId; + _dimensions = dimensions; + + Metadata = new(runtime!.Config.ServiceId, modelId: modelId, dimensions: dimensions); + } + + public void Dispose() + { + // Do not dispose of _runtime, as this instance doesn't own it. + } + + /// + public EmbeddingGeneratorMetadata Metadata { get; } + + /// + public object? GetService(Type serviceType, object? key) + { + if (serviceType is null) + { + throw new ArgumentNullException(nameof(serviceType)); + } + + return + key is not null ? null : + serviceType.IsInstanceOfType(_runtime) ? _runtime : + serviceType.IsInstanceOfType(this) ? this : + null; + } + + /// + public async Task>> GenerateAsync( + IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + if (values is null) + { + throw new ArgumentNullException(nameof(values)); + } + + GeneratedEmbeddings> embeddings = []; + int? totaltokens = null; + + foreach (string value in values) + { + var response = await _runtime.InvokeModelAsync(new() + { + ModelId = options?.ModelId ?? _modelId, + Accept = "application/json", + ContentType = "application/json", + Body = new MemoryStream(JsonSerializer.SerializeToUtf8Bytes(new EmbeddingRequest() + { + InputText = value, + Dimensions = options?.Dimensions ?? _dimensions, + }, BedrockJsonContext.Default.EmbeddingRequest)), + }, cancellationToken).ConfigureAwait(false); + + var er = JsonSerializer.Deserialize(response.Body, BedrockJsonContext.Default.EmbeddingResponse); + if (er?.Embedding is not null) + { + embeddings.Add(new(er.Embedding)); + + if (er.InputTextTokenCount is int inputTokens) + { + totaltokens ??= 0; + totaltokens += inputTokens; + } + } + } + + if (totaltokens is not null) + { + embeddings.Usage = new() { InputTokenCount = totaltokens.Value }; + } + + return embeddings; + } +} \ No newline at end of file diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockJsonContext.cs b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockJsonContext.cs new file mode 100644 index 000000000000..b5a04ba29e3d --- /dev/null +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockJsonContext.cs @@ -0,0 +1,77 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System.Collections.Generic; +#if NET8_0_OR_GREATER +using System.Diagnostics.CodeAnalysis; +#endif +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; + +namespace Amazon.BedrockRuntime; + +/// Provides type information for use with . +[JsonSourceGenerationOptions(JsonSerializerDefaults.Web, + UseStringEnumConverter = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + WriteIndented = true)] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(IDictionary))] +[JsonSerializable(typeof(bool))] +[JsonSerializable(typeof(int))] +[JsonSerializable(typeof(long))] +[JsonSerializable(typeof(float))] +[JsonSerializable(typeof(double))] +[JsonSerializable(typeof(string))] +[JsonSerializable(typeof(JsonElement))] +[JsonSerializable(typeof(JsonNode))] +[JsonSerializable(typeof(EmbeddingRequest))] +[JsonSerializable(typeof(EmbeddingResponse))] +internal partial class BedrockJsonContext : JsonSerializerContext +{ + /// Gets the singleton used as the default in JSON serialization operations. + public static readonly JsonSerializerOptions DefaultOptions = CreateDefaultToolJsonOptions(); + + /// Creates the default to use for serialization-related operations. +#if NET8_0_OR_GREATER + [UnconditionalSuppressMessage("AotAnalysis", "IL3050", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")] +#endif + private static JsonSerializerOptions CreateDefaultToolJsonOptions() + { + // If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize, + // and we want to be flexible in terms of what can be put into the various collections in the object model. + // Otherwise, use the source-generated options to enable trimming and Native AOT. + + if (JsonSerializer.IsReflectionEnabledByDefault) + { + // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above. + JsonSerializerOptions options = new(JsonSerializerDefaults.Web) + { + TypeInfoResolver = new DefaultJsonTypeInfoResolver(), + Converters = { new JsonStringEnumConverter() }, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + WriteIndented = true, + }; + + options.MakeReadOnly(); + return options; + } + + return Default.Options; + } +} \ No newline at end of file diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/Directory.Build.props b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/Directory.Build.props new file mode 100644 index 000000000000..95db0bd484df --- /dev/null +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/Directory.Build.props @@ -0,0 +1,6 @@ + + + + $(MSBuildProjectDirectory)\obj\$(MSBuildProjectName) + + \ No newline at end of file diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/EmbeddingRequest.cs b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/EmbeddingRequest.cs new file mode 100644 index 000000000000..70aeb6bb855b --- /dev/null +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/EmbeddingRequest.cs @@ -0,0 +1,40 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System.Text.Json.Serialization; + +#if NET8_0_OR_GREATER +using System.Diagnostics.CodeAnalysis; +#endif + +namespace Amazon.BedrockRuntime; + +internal sealed class EmbeddingRequest +{ + [JsonPropertyName("inputText")] + public string? InputText { get; set; } + + [JsonPropertyName("dimensions")] + public int? Dimensions { get; set; } +} + +internal sealed class EmbeddingResponse +{ + [JsonPropertyName("embedding")] + public float[]? Embedding { get; set; } + + [JsonPropertyName("inputTextTokenCount")] + public int? InputTextTokenCount { get; set; } +} \ No newline at end of file diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/Properties/AssemblyInfo.cs b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/Properties/AssemblyInfo.cs new file mode 100644 index 000000000000..9f33a53599e7 --- /dev/null +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/Properties/AssemblyInfo.cs @@ -0,0 +1,23 @@ +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +[assembly: AssemblyTitle("AWSSDK.Extensions.Bedrock.MEAI")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("Amazon.com, Inc")] +[assembly: AssemblyProduct("AWS SDK for .NET extensions for Bedrock integrating with Microsoft.Extensions.AI")] +[assembly: AssemblyDescription("AWS SDK for .NET extensions for Bedrock integrating with Microsoft.Extensions.AI")] +[assembly: AssemblyCopyright("Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.")] +[assembly: AssemblyTrademark("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +#if NETFRAMEWORK +[assembly: AssemblyVersion("4.0")] +#else +[assembly: AssemblyVersion("4.0.0.0")] +#endif +[assembly: AssemblyFileVersion("4.0.0.0")] diff --git a/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs b/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs new file mode 100644 index 000000000000..5d2bb7e2e497 --- /dev/null +++ b/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs @@ -0,0 +1,47 @@ +using Microsoft.Extensions.AI; +using System; +using Xunit; + +namespace Amazon.BedrockRuntime; + +public class BedrockChatClientTests +{ + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public void AsChatClient_InvalidArguments_Throws() + { + Assert.Throws("runtime", () => AmazonBedrockRuntimeExtensions.AsChatClient(null)); + } + + [Theory] + [Trait("UnitTest", "BedrockRuntime")] + [InlineData(null)] + [InlineData("claude")] + public void AsChatClient_ReturnsInstance(string modelId) + { + IAmazonBedrockRuntime runtime = new AmazonBedrockRuntimeClient("awsAccessKeyId", "awsSecretAccessKey", RegionEndpoint.USEast1); + IChatClient client = runtime.AsChatClient(modelId); + + Assert.NotNull(client); + Assert.Equal("Bedrock Runtime", client.Metadata.ProviderName); + Assert.Equal(modelId, client.Metadata.ModelId); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public void AsChatClient_GetService() + { + IAmazonBedrockRuntime runtime = new AmazonBedrockRuntimeClient("awsAccessKeyId", "awsSecretAccessKey", RegionEndpoint.USEast1); + IChatClient client = runtime.AsChatClient(); + + Assert.Same(runtime, client.GetService()); + Assert.Same(runtime, client.GetService()); + Assert.Same(client, client.GetService()); + + Assert.Null(client.GetService()); + + Assert.Null(client.GetService("key")); + Assert.Null(client.GetService("key")); + Assert.Null(client.GetService("key")); + } +} diff --git a/extensions/test/BedrockMEAITests/BedrockEmbeddingGeneratorTests.cs b/extensions/test/BedrockMEAITests/BedrockEmbeddingGeneratorTests.cs new file mode 100644 index 000000000000..35a2732b2d7a --- /dev/null +++ b/extensions/test/BedrockMEAITests/BedrockEmbeddingGeneratorTests.cs @@ -0,0 +1,50 @@ +using Microsoft.Extensions.AI; +using System; +using Xunit; + +namespace Amazon.BedrockRuntime; + +public class BedrockEmbeddingGeneratorTests +{ + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public void AsEmbeddingGenerator_InvalidArguments_Throws() + { + Assert.Throws("runtime", () => AmazonBedrockRuntimeExtensions.AsEmbeddingGenerator(null)); + } + + [Theory] + [Trait("UnitTest", "BedrockRuntime")] + [InlineData(null, null)] + [InlineData("titan", null)] + [InlineData(null, 42)] + [InlineData("titan", 42)] + public void AsEmbeddingGenerator_ReturnsInstance(string modelId, int? dimensions) + { + IAmazonBedrockRuntime runtime = new AmazonBedrockRuntimeClient("awsAccessKeyId", "awsSecretAccessKey", RegionEndpoint.USEast1); + IEmbeddingGenerator> generator = runtime.AsEmbeddingGenerator(modelId, dimensions); + + Assert.NotNull(generator); + Assert.Equal("Bedrock Runtime", generator.Metadata.ProviderName); + Assert.Equal(modelId, generator.Metadata.ModelId); + Assert.Equal(dimensions, generator.Metadata.Dimensions); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public void AsEmbeddingGenerator_GetService() + { + IAmazonBedrockRuntime runtime = new AmazonBedrockRuntimeClient("awsAccessKeyId", "awsSecretAccessKey", RegionEndpoint.USEast1); + IEmbeddingGenerator> generator = runtime.AsEmbeddingGenerator(); + + Assert.Same(runtime, generator.GetService()); + Assert.Same(runtime, generator.GetService()); + Assert.Same(generator, generator.GetService>>()); + + Assert.Null(generator.GetService()); + + Assert.Null(generator.GetService("key")); + Assert.Null(generator.GetService("key")); + Assert.Null(generator.GetService>>("key")); + } +} diff --git a/extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj b/extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj new file mode 100644 index 000000000000..6773fb847e94 --- /dev/null +++ b/extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj @@ -0,0 +1,35 @@ + + + net472 + BedrockMEAITests + BedrockMEAITests + + false + false + false + false + false + false + false + false + + true + Latest + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EnumerableEventStream.cs b/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EnumerableEventStream.cs index 30da7cffc453..d75ec08ed3d3 100644 --- a/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EnumerableEventStream.cs +++ b/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EnumerableEventStream.cs @@ -24,6 +24,8 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.IO; +using System.Threading; + #if AWS_ASYNC_API using System.Threading.Tasks; #endif @@ -48,9 +50,9 @@ namespace Amazon.Runtime.EventStreams.Internal [SuppressMessage("Microsoft.Naming", "CA1710", Justification = "EventStreamCollection is not descriptive.")] [SuppressMessage("Microsoft.Design", "CA1063", Justification = "IDisposable is a transient interface from IEventStream. Users need to be able to call Dispose.")] #if NET8_0_OR_GREATER - public abstract class EnumerableEventStream : EventStream, IEnumerableEventStream where T : IEventStreamEvent where TE : EventStreamException, new() + public abstract class EnumerableEventStream : EventStream, IEnumerableEventStream, IAsyncEnumerable where T : IEventStreamEvent where TE : EventStreamException, new() #else - public abstract class EnumerableEventStream : EventStream, IEnumerableEventStream where T : IEventStreamEvent where TE : EventStreamException, new() + public abstract class EnumerableEventStream : EventStream, IEnumerableEventStream, IAsyncEnumerable where T : IEventStreamEvent where TE : EventStreamException, new() #endif { private const string MutuallyExclusiveExceptionMessage = "Stream has already begun processing. Event-driven and Enumerable traversals of the stream are mutually exclusive. " + @@ -145,6 +147,67 @@ public IEnumerator GetEnumerator() } } + /// + /// Returns an async enumerator that asynchronously iterates through the collection. + /// + /// An async enumerator that can be used to iterate through the collection. + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken) + { + // This implementation of this method is identical to that of GetEnumerator, except that + // instead of using ReadFromStream, it uses ReadFromStreamAsync. The two implementations + // should be kept in sync. + + if (IsProcessing) + { + // If the queue has already begun processing, refuse to enumerate. + throw new InvalidOperationException(MutuallyExclusiveExceptionMessage); + } + + // There could be more than 1 message created per decoder cycle. + var events = new Queue(); + + // Opting out of events - letting the enumeration handle everything. + IsEnumerated = true; + IsProcessing = true; + + // Enumeration is just magic over the event driven mechanism. + EventReceived += (sender, args) => events.Enqueue(args.EventStreamEvent); + + var buffer = new byte[BufferSize]; + + while (IsProcessing) + { + // If there are already events ready to be served, do not ask for more. + if (events.Count > 0) + { + var ev = events.Dequeue(); + // Enumeration handles terminal events on behalf of the user. + if (ev is IEventStreamTerminalEvent) + { + IsProcessing = false; + Dispose(); + } + + yield return ev; + } + else + { + try + { + await ReadFromStreamAsync(buffer, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + IsProcessing = false; + Dispose(); + + // Wrap exceptions as needed to match event-driven behavior. + throw WrapException(ex); + } + } + } + } + /// /// Returns an enumerator that iterates through a collection. /// diff --git a/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs b/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs index f7f715f7adb6..a7fee63cb083 100644 --- a/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs +++ b/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs @@ -17,6 +17,8 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.IO; +using System.Threading; + #if AWS_ASYNC_API using System.Threading.Tasks; #else @@ -351,9 +353,21 @@ protected void ReadFromStream(byte[] buffer) /// each message it decodes. /// /// The buffer to store the read bytes from the stream. - protected async Task ReadFromStreamAsync(byte[] buffer) + protected Task ReadFromStreamAsync(byte[] buffer) => ReadFromStreamAsync(buffer, CancellationToken.None); + + /// + /// Reads from the stream into the buffer. It then passes the buffer to the decoder, which raises an event for + /// each message it decodes. + /// + /// The buffer to store the read bytes from the stream. + /// The token to monitor for cancellation requests. + protected async Task ReadFromStreamAsync(byte[] buffer, CancellationToken cancellationToken) { - var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false); +#if NETCOREAPP + var bytesRead = await NetworkStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); +#else + var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); +#endif if (bytesRead > 0) { // Decoder raises MessageReceived for every message it encounters. diff --git a/sdk/src/Services/BedrockRuntime/AWSSDK.BedrockRuntime.NetFramework.csproj b/sdk/src/Services/BedrockRuntime/AWSSDK.BedrockRuntime.NetFramework.csproj index 48249ff06b1d..aa1cb41eb5c6 100644 --- a/sdk/src/Services/BedrockRuntime/AWSSDK.BedrockRuntime.NetFramework.csproj +++ b/sdk/src/Services/BedrockRuntime/AWSSDK.BedrockRuntime.NetFramework.csproj @@ -1,4 +1,4 @@ - + true net472