Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Microsoft.Extensions.AI's IChatClient / IEmbeddingGenerator for IAmazonBedrockRuntime #3545

Open
wants to merge 8 commits into
base: v4-development
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading;

#if AWS_ASYNC_API
using System.Threading.Tasks;
#endif
Expand All @@ -48,9 +50,9 @@ namespace Amazon.Runtime.EventStreams.Internal
[SuppressMessage("Microsoft.Naming", "CA1710", Justification = "EventStreamCollection is not descriptive.")]
[SuppressMessage("Microsoft.Design", "CA1063", Justification = "IDisposable is a transient interface from IEventStream. Users need to be able to call Dispose.")]
#if NET8_0_OR_GREATER
public abstract class EnumerableEventStream<T, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TE> : EventStream<T, TE>, IEnumerableEventStream<T, TE> where T : IEventStreamEvent where TE : EventStreamException, new()
public abstract class EnumerableEventStream<T, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TE> : EventStream<T, TE>, IEnumerableEventStream<T, TE>, IAsyncEnumerable<T> where T : IEventStreamEvent where TE : EventStreamException, new()
#else
public abstract class EnumerableEventStream<T, TE> : EventStream<T, TE>, IEnumerableEventStream<T, TE> where T : IEventStreamEvent where TE : EventStreamException, new()
public abstract class EnumerableEventStream<T, TE> : EventStream<T, TE>, IEnumerableEventStream<T, TE>, IAsyncEnumerable<T> where T : IEventStreamEvent where TE : EventStreamException, new()
#endif
{
private const string MutuallyExclusiveExceptionMessage = "Stream has already begun processing. Event-driven and Enumerable traversals of the stream are mutually exclusive. " +
Expand Down Expand Up @@ -145,6 +147,67 @@ public IEnumerator<T> GetEnumerator()
}
}

/// <summary>
/// Returns an async enumerator that asynchronously iterates through the collection.
/// </summary>
/// <returns>An async enumerator that can be used to iterate through the collection.</returns>
public async IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken)
{
// This implementation of this method is identical to that of GetEnumerator, except that
// instead of using ReadFromStream, it uses ReadFromStreamAsync. The two implementations
// should be kept in sync.

if (IsProcessing)
{
// If the queue has already begun processing, refuse to enumerate.
throw new InvalidOperationException(MutuallyExclusiveExceptionMessage);
}

// There could be more than 1 message created per decoder cycle.
var events = new Queue<T>();

// Opting out of events - letting the enumeration handle everything.
IsEnumerated = true;
IsProcessing = true;

// Enumeration is just magic over the event driven mechanism.
EventReceived += (sender, args) => events.Enqueue(args.EventStreamEvent);

var buffer = new byte[BufferSize];

while (IsProcessing)
{
// If there are already events ready to be served, do not ask for more.
if (events.Count > 0)
{
var ev = events.Dequeue();
// Enumeration handles terminal events on behalf of the user.
if (ev is IEventStreamTerminalEvent)
{
IsProcessing = false;
Dispose();
}

yield return ev;
}
else
{
try
{
await ReadFromStreamAsync(buffer, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
IsProcessing = false;
Dispose();

// Wrap exceptions as needed to match event-driven behavior.
throw WrapException(ex);
}
}
}
}

/// <summary>
/// Returns an enumerator that iterates through a collection.
/// </summary>
Expand Down
18 changes: 16 additions & 2 deletions sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading;

#if AWS_ASYNC_API
using System.Threading.Tasks;
#else
Expand Down Expand Up @@ -351,9 +353,21 @@ protected void ReadFromStream(byte[] buffer)
/// each message it decodes.
/// </summary>
/// <param name="buffer">The buffer to store the read bytes from the stream.</param>
protected async Task ReadFromStreamAsync(byte[] buffer)
protected Task ReadFromStreamAsync(byte[] buffer) => ReadFromStreamAsync(buffer, CancellationToken.None);

/// <summary>
/// Reads from the stream into the buffer. It then passes the buffer to the decoder, which raises an event for
/// each message it decodes.
/// </summary>
/// <param name="buffer">The buffer to store the read bytes from the stream.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
protected async Task ReadFromStreamAsync(byte[] buffer, CancellationToken cancellationToken)
{
var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);
#if NETCOREAPP
var bytesRead = await NetworkStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
#else
var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
#endif
if (bytesRead > 0)
{
// Decoder raises MessageReceived for every message it encounters.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<RunAnalyzersDuringBuild Condition="'$(RunAnalyzersDuringBuild)'==''">true</RunAnalyzersDuringBuild>
<TargetFramework>net472</TargetFramework>
Expand Down Expand Up @@ -64,6 +64,10 @@
<ProjectReference Include="../../Core/AWSSDK.Core.NetFramework.csproj"/>
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24525.1" />
</ItemGroup>

<ItemGroup Condition="$(RunAnalyzersDuringBuild)">
<PackageReference Include="Microsoft.CodeAnalysis.FxCopAnalyzers" Version="2.9.3">
<PrivateAssets>all</PrivateAssets>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@
<ProjectReference Include="../../Core/AWSSDK.Core.NetStandard.csproj"/>
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24525.1" />
</ItemGroup>

<ItemGroup Condition="$(RunAnalyzersDuringBuild)">
<PackageReference Include="Microsoft.CodeAnalysis.FxCopAnalyzers" Version="2.9.3">
<PrivateAssets>all</PrivateAssets>
Expand Down
Loading