# .Net: Replaced IMemoryStore with IVectorStore in examples (#9833)

### Motivation and Context

Replaced `IMemoryStore` usage with `IVectorStore` in Semantic Kernel
examples.
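
The change follows the same pattern in each sample: the concrete `IMemoryStore` registration is replaced with an `AddXxxVectorStore` registration, and reads/writes go through a strongly typed record collection instead of raw `MemoryRecord` instances. A minimal sketch of the new shape, distilled from the diff below (the collection name here is illustrative, and the `CacheRecord` type mirrors the record added in the caching sample):

```csharp
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel;

var builder = Kernel.CreateBuilder();

// Before: builder.Services.AddSingleton<IMemoryStore>(_ => new VolatileMemoryStore());
// After: register a vector store; Redis and Azure Cosmos DB MongoDB have
// analogous AddRedisVectorStore / AddAzureCosmosDBMongoDBVectorStore extensions.
builder.Services.AddInMemoryVectorStore();

var kernel = builder.Build();
var vectorStore = kernel.GetRequiredService<IVectorStore>();

// Data access is a strongly typed collection; "cache" is an illustrative name.
var collection = vectorStore.GetCollection<string, CacheRecord>("cache");
await collection.CreateCollectionIfNotExistsAsync();

// Record shape mirroring the one added in this PR.
internal sealed class CacheRecord
{
    [VectorStoreRecordKey]
    public string Id { get; set; }

    [VectorStoreRecordData]
    public string Prompt { get; set; }

    [VectorStoreRecordData]
    public string Result { get; set; }

    [VectorStoreRecordVector(Dimensions: 1536)]
    public ReadOnlyMemory<float> PromptEmbedding { get; set; }
}
```
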

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
dmytrostruk authored Nov 27, 2024
1 parent 8dadef2 commit e780d7b
Showing 4 changed files with 188 additions and 92 deletions.
#### dotnet/samples/Concepts/Caching/SemanticCachingWithFilters.cs (101 additions, 48 deletions)
```diff
@@ -1,11 +1,11 @@
 // Copyright (c) Microsoft. All rights reserved.
 
 using System.Diagnostics;
+using Azure.Identity;
 using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.VectorData;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB;
 using Microsoft.SemanticKernel.Connectors.Redis;
-using Microsoft.SemanticKernel.Memory;
+using Microsoft.SemanticKernel.Embeddings;
 
 namespace Caching;
 
@@ -18,20 +18,17 @@ namespace Caching;
 /// </summary>
 public class SemanticCachingWithFilters(ITestOutputHelper output) : BaseTest(output)
 {
-    /// <summary>
-    /// Similarity/relevance score, from 0 to 1, where 1 means exact match.
-    /// It's possible to change this value during testing to see how caching logic will behave.
-    /// </summary>
-    private const double SimilarityScore = 0.9;
-
     /// <summary>
     /// Executing similar requests two times using in-memory caching store to compare execution time and results.
     /// Second execution is faster, because the result is returned from cache.
     /// </summary>
     [Fact]
     public async Task InMemoryCacheAsync()
     {
-        var kernel = GetKernelWithCache(_ => new VolatileMemoryStore());
+        var kernel = GetKernelWithCache(services =>
+        {
+            services.AddInMemoryVectorStore();
+        });
 
         var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?");
         var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?");
@@ -53,12 +50,15 @@ public async Task InMemoryCacheAsync()
     /// <summary>
     /// Executing similar requests two times using Redis caching store to compare execution time and results.
     /// Second execution is faster, because the result is returned from cache.
-    /// How to run Redis on Docker locally: https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/docker/
+    /// How to run Redis on Docker locally: https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/docker/.
     /// </summary>
     [Fact]
     public async Task RedisCacheAsync()
     {
-        var kernel = GetKernelWithCache(_ => new RedisMemoryStore("localhost:6379", vectorSize: 1536));
+        var kernel = GetKernelWithCache(services =>
+        {
+            services.AddRedisVectorStore("localhost:6379");
+        });
 
         var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?");
         var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?");
@@ -84,10 +84,12 @@ public async Task RedisCacheAsync()
     [Fact]
     public async Task AzureCosmosDBMongoDBCacheAsync()
     {
-        var kernel = GetKernelWithCache(_ => new AzureCosmosDBMongoDBMemoryStore(
-            TestConfiguration.AzureCosmosDbMongoDb.ConnectionString,
-            TestConfiguration.AzureCosmosDbMongoDb.DatabaseName,
-            new(dimensions: 1536)));
+        var kernel = GetKernelWithCache(services =>
+        {
+            services.AddAzureCosmosDBMongoDBVectorStore(
+                TestConfiguration.AzureCosmosDbMongoDb.ConnectionString,
+                TestConfiguration.AzureCosmosDbMongoDb.DatabaseName);
+        });
 
         var result1 = await ExecuteAsync(kernel, "First run", "What's the tallest building in New York?");
         var result2 = await ExecuteAsync(kernel, "Second run", "What is the highest building in New York City?");
@@ -110,27 +112,41 @@ public async Task AzureCosmosDBMongoDBCacheAsync()
     /// <summary>
     /// Returns <see cref="Kernel"/> instance with required registered services.
     /// </summary>
-    private Kernel GetKernelWithCache(Func<IServiceProvider, IMemoryStore> cacheFactory)
+    private Kernel GetKernelWithCache(Action<IServiceCollection> configureVectorStore)
     {
         var builder = Kernel.CreateBuilder();
 
-        // Add Azure OpenAI chat completion service
-        builder.AddAzureOpenAIChatCompletion(
-            TestConfiguration.AzureOpenAI.ChatDeploymentName,
-            TestConfiguration.AzureOpenAI.Endpoint,
-            TestConfiguration.AzureOpenAI.ApiKey);
-
-        // Add Azure OpenAI text embedding generation service
-        builder.AddAzureOpenAITextEmbeddingGeneration(
-            TestConfiguration.AzureOpenAIEmbeddings.DeploymentName,
-            TestConfiguration.AzureOpenAIEmbeddings.Endpoint,
-            TestConfiguration.AzureOpenAIEmbeddings.ApiKey);
-
-        // Add memory store for caching purposes (e.g. in-memory, Redis, Azure Cosmos DB)
-        builder.Services.AddSingleton<IMemoryStore>(cacheFactory);
+        if (!string.IsNullOrWhiteSpace(TestConfiguration.AzureOpenAI.ApiKey))
+        {
+            // Add Azure OpenAI chat completion service
+            builder.AddAzureOpenAIChatCompletion(
+                TestConfiguration.AzureOpenAI.ChatDeploymentName,
+                TestConfiguration.AzureOpenAI.Endpoint,
+                TestConfiguration.AzureOpenAI.ApiKey);
+
+            // Add Azure OpenAI text embedding generation service
+            builder.AddAzureOpenAITextEmbeddingGeneration(
+                TestConfiguration.AzureOpenAIEmbeddings.DeploymentName,
+                TestConfiguration.AzureOpenAIEmbeddings.Endpoint,
+                TestConfiguration.AzureOpenAI.ApiKey);
+        }
+        else
+        {
+            // Add Azure OpenAI chat completion service
+            builder.AddAzureOpenAIChatCompletion(
+                TestConfiguration.AzureOpenAI.ChatDeploymentName,
+                TestConfiguration.AzureOpenAI.Endpoint,
+                new AzureCliCredential());
+
+            // Add Azure OpenAI text embedding generation service
+            builder.AddAzureOpenAITextEmbeddingGeneration(
+                TestConfiguration.AzureOpenAIEmbeddings.DeploymentName,
+                TestConfiguration.AzureOpenAIEmbeddings.Endpoint,
+                new AzureCliCredential());
+        }
 
-        // Add text memory service that will be used to generate embeddings and query/store data.
-        builder.Services.AddSingleton<ISemanticTextMemory, SemanticTextMemory>();
+        // Add vector store for caching purposes (e.g. in-memory, Redis, Azure Cosmos DB)
+        configureVectorStore(builder.Services);
 
         // Add prompt render filter to query cache and check if rendered prompt was already answered.
         builder.Services.AddSingleton<IPromptRenderFilter, PromptCacheFilter>();
@@ -164,7 +180,10 @@ public class CacheBaseFilter
     /// <summary>
     /// Filter which is executed during prompt rendering operation.
     /// </summary>
-    public sealed class PromptCacheFilter(ISemanticTextMemory semanticTextMemory) : CacheBaseFilter, IPromptRenderFilter
+    public sealed class PromptCacheFilter(
+        ITextEmbeddingGenerationService textEmbeddingGenerationService,
+        IVectorStore vectorStore)
+        : CacheBaseFilter, IPromptRenderFilter
     {
         public async Task OnPromptRenderAsync(PromptRenderContext context, Func<PromptRenderContext, Task> next)
         {
@@ -174,20 +193,22 @@ public async Task OnPromptRenderAsync(PromptRenderContext context, Func<PromptRe
             // Get rendered prompt
            var prompt = context.RenderedPrompt!;
 
-            // Search for similar prompts in cache with provided similarity/relevance score
-            var searchResult = await semanticTextMemory.SearchAsync(
-                CollectionName,
-                prompt,
-                limit: 1,
-                minRelevanceScore: SimilarityScore).FirstOrDefaultAsync();
+            var promptEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(prompt);
+
+            var collection = vectorStore.GetCollection<string, CacheRecord>(CollectionName);
+            await collection.CreateCollectionIfNotExistsAsync();
+
+            // Search for similar prompts in cache.
+            var searchResults = await collection.VectorizedSearchAsync(promptEmbedding, new() { Top = 1 }, context.CancellationToken);
+            var searchResult = (await searchResults.Results.FirstOrDefaultAsync())?.Record;
 
             // If result exists, return it.
             if (searchResult is not null)
             {
                 // Override function result. This will prevent calling LLM and will return result immediately.
-                context.Result = new FunctionResult(context.Function, searchResult.Metadata.AdditionalMetadata)
+                context.Result = new FunctionResult(context.Function, searchResult.Result)
                 {
-                    Metadata = new Dictionary<string, object?> { [RecordIdKey] = searchResult.Metadata.Id }
+                    Metadata = new Dictionary<string, object?> { [RecordIdKey] = searchResult.Id }
                 };
             }
         }
@@ -196,7 +217,10 @@ public async Task OnPromptRenderAsync(PromptRenderContext context, Func<PromptRe
     /// <summary>
     /// Filter which is executed during function invocation.
     /// </summary>
-    public sealed class FunctionCacheFilter(ISemanticTextMemory semanticTextMemory) : CacheBaseFilter, IFunctionInvocationFilter
+    public sealed class FunctionCacheFilter(
+        ITextEmbeddingGenerationService textEmbeddingGenerationService,
+        IVectorStore vectorStore)
+        : CacheBaseFilter, IFunctionInvocationFilter
     {
         public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
         {
@@ -212,12 +236,22 @@ public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, F
                 // Get cache record id if result was cached previously or generate new id.
                 var recordId = context.Result.Metadata?.GetValueOrDefault(RecordIdKey, Guid.NewGuid().ToString()) as string;
 
+                // Generate prompt embedding.
+                var promptEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(context.Result.RenderedPrompt);
+
                 // Cache rendered prompt and LLM result.
-                await semanticTextMemory.SaveInformationAsync(
-                    CollectionName,
-                    context.Result.RenderedPrompt,
-                    recordId!,
-                    additionalMetadata: result.ToString());
+                var collection = vectorStore.GetCollection<string, CacheRecord>(CollectionName);
+                await collection.CreateCollectionIfNotExistsAsync();
+
+                var cacheRecord = new CacheRecord
+                {
+                    Id = recordId!,
+                    Prompt = context.Result.RenderedPrompt,
+                    Result = result.ToString(),
+                    PromptEmbedding = promptEmbedding
+                };
+
+                await collection.UpsertAsync(cacheRecord, cancellationToken: context.CancellationToken);
             }
         }
     }
@@ -245,4 +279,23 @@ private async Task<FunctionResult> ExecuteAsync(Kernel kernel, string title, str
     }
 
     #endregion
+
+    #region Vector Store Record
+
+    private sealed class CacheRecord
+    {
+        [VectorStoreRecordKey]
+        public string Id { get; set; }
+
+        [VectorStoreRecordData]
+        public string Prompt { get; set; }
+
+        [VectorStoreRecordData]
+        public string Result { get; set; }
+
+        [VectorStoreRecordVector(Dimensions: 1536)]
+        public ReadOnlyMemory<float> PromptEmbedding { get; set; }
+    }
+
+    #endregion
 }
```
#### dotnet/samples/Concepts/Optimization/FrugalGPTWithFilters.cs (35 additions, 14 deletions)
```diff
@@ -2,10 +2,11 @@
 
 using System.Runtime.CompilerServices;
 using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.VectorData;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.InMemory;
 using Microsoft.SemanticKernel.Embeddings;
-using Microsoft.SemanticKernel.Memory;
 using Microsoft.SemanticKernel.PromptTemplates.Handlebars;
 using Microsoft.SemanticKernel.Services;
 
@@ -97,11 +98,11 @@ public async Task ReducePromptSizeAsync()
 
         // Add few-shot prompt optimization filter.
         // The filter uses in-memory store for vector similarity search and text embedding generation service to generate embeddings.
-        var memoryStore = new VolatileMemoryStore();
+        var vectorStore = new InMemoryVectorStore();
         var textEmbeddingGenerationService = kernel.GetRequiredService<ITextEmbeddingGenerationService>();
 
         // Register optimization filter.
-        kernel.PromptRenderFilters.Add(new FewShotPromptOptimizationFilter(memoryStore, textEmbeddingGenerationService));
+        kernel.PromptRenderFilters.Add(new FewShotPromptOptimizationFilter(vectorStore, textEmbeddingGenerationService));
 
         // Get result again and compare the usage.
         result = await kernel.InvokeAsync(function, arguments);
@@ -167,7 +168,7 @@ public async Task LLMCascadeAsync()
     /// which are similar to original request.
     /// </summary>
     private sealed class FewShotPromptOptimizationFilter(
-        IMemoryStore memoryStore,
+        IVectorStore vectorStore,
         ITextEmbeddingGenerationService textEmbeddingGenerationService) : IPromptRenderFilter
     {
         /// <summary>
@@ -176,7 +177,7 @@ private sealed class FewShotPromptOptimizationFilter(
         private const int TopN = 5;
 
         /// <summary>
-        /// Collection name to use in memory store.
+        /// Collection name to use in vector store.
         /// </summary>
         private const string CollectionName = "examples";
 
@@ -188,30 +189,38 @@ public async Task OnPromptRenderAsync(PromptRenderContext context, Func<PromptRe
 
             if (examples is { Count: > 0 } && !string.IsNullOrEmpty(request))
             {
-                var memoryRecords = new List<MemoryRecord>();
+                var exampleRecords = new List<ExampleRecord>();
 
                 // Generate embedding for each example.
                 var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(examples);
 
-                // Create memory record instances with example text and embedding.
+                // Create vector store record instances with example text and embedding.
                 for (var i = 0; i < examples.Count; i++)
                 {
-                    memoryRecords.Add(MemoryRecord.LocalRecord(Guid.NewGuid().ToString(), examples[i], "description", embeddings[i]));
+                    exampleRecords.Add(new ExampleRecord
+                    {
+                        Id = Guid.NewGuid().ToString(),
+                        Example = examples[i],
+                        ExampleEmbedding = embeddings[i]
+                    });
                 }
 
-                // Create collection and upsert all memory records for search.
+                // Create collection and upsert all vector store records for search.
                 // It's possible to do it only once and re-use the same examples for future requests.
-                await memoryStore.CreateCollectionAsync(CollectionName);
-                await memoryStore.UpsertBatchAsync(CollectionName, memoryRecords).ToListAsync();
+                var collection = vectorStore.GetCollection<string, ExampleRecord>(CollectionName);
+                await collection.CreateCollectionIfNotExistsAsync(context.CancellationToken);
+
+                await collection.UpsertBatchAsync(exampleRecords, cancellationToken: context.CancellationToken).ToListAsync(context.CancellationToken);
 
                 // Generate embedding for original request.
-                var requestEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(request);
+                var requestEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(request, cancellationToken: context.CancellationToken);
 
                 // Find top N examples which are similar to original request.
-                var topNExamples = await memoryStore.GetNearestMatchesAsync(CollectionName, requestEmbedding, TopN).ToListAsync();
+                var searchResults = await collection.VectorizedSearchAsync(requestEmbedding, new() { Top = TopN }, cancellationToken: context.CancellationToken);
+                var topNExamples = (await searchResults.Results.ToListAsync(context.CancellationToken)).Select(l => l.Record).ToList();
 
                 // Override arguments to use only top N examples, which will be sent to LLM.
-                context.Arguments["Examples"] = topNExamples.Select(l => l.Item1.Metadata.Text);
+                context.Arguments["Examples"] = topNExamples.Select(l => l.Example);
             }
 
             // Continue prompt rendering operation.
@@ -305,4 +314,16 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessa
             yield return new StreamingChatMessageContent(AuthorRole.Assistant, mockResult);
         }
     }
+
+    private sealed class ExampleRecord
+    {
+        [VectorStoreRecordKey]
+        public string Id { get; set; }
+
+        [VectorStoreRecordData]
+        public string Example { get; set; }
+
+        [VectorStoreRecordVector]
+        public ReadOnlyMemory<float> ExampleEmbedding { get; set; }
+    }
 }
```