Skip to content

Commit

Permalink
.Net: Remove Memory property from SKContext (#2079)
Browse files Browse the repository at this point in the history
### Motivation and Context
Different skills and functions require different types of memory. Some
skills might just need local, high-access stores while others need
remote, authenticated access. Some are powered by embedding (vectors)
indexed in the cloud, while others may be generating lightweight
embeddings locally, on the fly. We cannot assume that memory is "one
size fits all", or that even all the skills/functions in a single Plan
use the same memory provider.

### Description
Instead of surfacing one single 'Memory' instance to all skills, we're
peeling this back and simplifying somewhat. In this first pass, the
`Memory` property of SKContext is removed. The updated guidance will be
to instantiate Skill classes with the memory provider you'd like them to
have access to. This also opens the door for the TextMemorySkill to be
registered several times in the kernel (with different aliases),
granting access to different memory stores.

BREAKING CHANGE: though this is not expected to have a large impact,
updates will be required for anyone:
- using `SKContext.Memory`: will need to pass it into the function via
an alternate route; likely Skill constructor injection
- or those using the `new SKContext()` constructor and explicity passing
the optional ISemanticTextMemory: creating a new SKContext with the new
constructor is not advised, as this is typically done by a call through
the `Kernel`. The `new SKContext` constructor is planned to be hidden in
a future release.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#dev-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄

---------

Co-authored-by: Dmytro Struk <[email protected]>
  • Loading branch information
shawncal and dmytrostruk authored Jul 20, 2023
1 parent f1fb50f commit 9dd8604
Show file tree
Hide file tree
Showing 21 changed files with 128 additions and 324 deletions.
12 changes: 6 additions & 6 deletions dotnet/samples/KernelSyntaxExamples/Example15_MemorySkill.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ public static async Task RunAsync()
// ========= Store memories using semantic function =========

// Add Memory as a skill for other functions
var memorySkill = new TextMemorySkill();
kernel.ImportSkill(new TextMemorySkill());
var memorySkill = new TextMemorySkill(kernel.Memory);
kernel.ImportSkill(memorySkill);

// Build a semantic function that saves info to memory
const string SaveFunctionDefinition = "{{save $info}}";
Expand All @@ -48,7 +48,7 @@ public static async Task RunAsync()
// ========= Test memory remember =========
Console.WriteLine("========= Example: Recalling a Memory =========");

var answer = await memorySkill.RetrieveAsync(MemoryCollectionName, "info5", context);
var answer = await memorySkill.RetrieveAsync(MemoryCollectionName, "info5", logger: context.Log);
Console.WriteLine("Memory associated with 'info1': {0}", answer);
/*
Output:
Expand All @@ -58,11 +58,11 @@ public static async Task RunAsync()
// ========= Test memory recall =========
Console.WriteLine("========= Example: Recalling an Idea =========");

answer = await memorySkill.RecallAsync("where did I grow up?", MemoryCollectionName, relevance: null, limit: 2, context: context);
answer = await memorySkill.RecallAsync("where did I grow up?", MemoryCollectionName, relevance: null, limit: 2, logger: context.Log);
Console.WriteLine("Ask: where did I grow up?");
Console.WriteLine("Answer:\n{0}", answer);

answer = await memorySkill.RecallAsync("where do I live?", MemoryCollectionName, relevance: null, limit: 2, context: context);
answer = await memorySkill.RecallAsync("where do I live?", MemoryCollectionName, relevance: null, limit: 2, logger: context.Log);
Console.WriteLine("Ask: where do I live?");
Console.WriteLine("Answer:\n{0}", answer);

Expand Down Expand Up @@ -131,7 +131,7 @@ My name is Andrea and my family is from New York. I work as a tourist operator.
*/

context[TextMemorySkill.KeyParam] = "info1";
await memorySkill.RemoveAsync(MemoryCollectionName, "info1", context);
await memorySkill.RemoveAsync(MemoryCollectionName, "info1", logger: context.Log);

result = await aboutMeOracle.InvokeAsync("Tell me a bit about myself", context);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ public static async Task RunAsync()
IDictionary<string, ISKFunction> skills = LoadQASkill(kernel);
SKContext context = CreateContextQueryContext(kernel);

// Create a memory store using the VolatileMemoryStore and the embedding generator registered in the kernel
kernel.ImportSkill(new TextMemorySkill(kernel.Memory));

// Setup defined memories for recall
await RememberFactsAsync(kernel);

Expand Down Expand Up @@ -84,7 +87,7 @@ private static SKContext CreateContextQueryContext(IKernel kernel)

private static async Task RememberFactsAsync(IKernel kernel)
{
kernel.ImportSkill(new TextMemorySkill());
kernel.ImportSkill(new TextMemorySkill(kernel.Memory));

List<string> memoriesToSave = new()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,6 @@ public static Task RunAsync()
apiKey: azureOpenAIKey)
.Build();

// Example: how to use a custom memory storage and custom embedding generator
var kernel5 = Kernel.Builder
.WithLogger(NullLogger.Instance)
.WithMemoryStorageAndTextEmbeddingGeneration(memoryStorage, textEmbeddingGenerator)
.Build();

// Example: how to use a custom memory storage
var kernel6 = Kernel.Builder
.WithLogger(NullLogger.Instance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Planning;
using Microsoft.SemanticKernel.SemanticFunctions;
Expand Down Expand Up @@ -99,14 +97,11 @@ public async Task MalformedJsonThrowsAsync()
await Assert.ThrowsAsync<PlanningException>(async () => await planner.CreatePlanAsync("goal"));
}

private Mock<IKernel> CreateMockKernelAndFunctionFlowWithTestString(string testPlanString, Mock<ISkillCollection>? mockSkills = null)
private Mock<IKernel> CreateMockKernelAndFunctionFlowWithTestString(string testPlanString, Mock<ISkillCollection>? skills = null)
{
var kernel = new Mock<IKernel>();

var memory = new Mock<ISemanticTextMemory>();
var skills = mockSkills;

if (mockSkills == null)
if (skills is null)
{
skills = new Mock<ISkillCollection>();

Expand All @@ -116,16 +111,11 @@ private Mock<IKernel> CreateMockKernelAndFunctionFlowWithTestString(string testP

var returnContext = new SKContext(
new ContextVariables(testPlanString),
memory.Object,
skills!.Object,
new Mock<ILogger>().Object
skills.Object
);

var context = new SKContext(
new ContextVariables(),
memory.Object,
skills!.Object,
new Mock<ILogger>().Object
skills: skills.Object
);

var mockFunctionFlowFunction = new Mock<ISKFunction>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ public async Task CanCallGetAvailableFunctionsWithNoFunctionsAsync()
.Returns(asyncEnumerable);

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(variables, memory.Object, skills.ReadOnlySkillCollection, logger, cancellationToken);
var config = new SequentialPlannerConfig();
var context = new SKContext(variables, skills.ReadOnlySkillCollection, logger, cancellationToken);
var config = new SequentialPlannerConfig() { Memory = memory.Object };
var semanticQuery = "test";

// Act
Expand Down Expand Up @@ -98,8 +98,8 @@ public async Task CanCallGetAvailableFunctionsWithFunctionsAsync()
skills.SetupGet(x => x.ReadOnlySkillCollection).Returns(skills.Object);

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(variables, memory.Object, skills.Object, logger, cancellationToken);
var config = new SequentialPlannerConfig();
var context = new SKContext(variables, skills.Object, logger, cancellationToken);
var config = new SequentialPlannerConfig() { Memory = memory.Object };
var semanticQuery = "test";

// Act
Expand Down Expand Up @@ -164,8 +164,8 @@ public async Task CanCallGetAvailableFunctionsWithFunctionsWithRelevancyAsync()
skills.SetupGet(x => x.ReadOnlySkillCollection).Returns(skills.Object);

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(variables, memory.Object, skills.Object, logger, cancellationToken);
var config = new SequentialPlannerConfig { RelevancyThreshold = 0.78 };
var context = new SKContext(variables, skills.Object, logger, cancellationToken);
var config = new SequentialPlannerConfig { RelevancyThreshold = 0.78, Memory = memory.Object };
var semanticQuery = "test";

// Act
Expand Down Expand Up @@ -217,8 +217,8 @@ public async Task CanCallGetAvailableFunctionsAsyncWithDefaultRelevancyAsync()
.Returns(asyncEnumerable);

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(variables, memory.Object, skills.ReadOnlySkillCollection, logger, cancellationToken);
var config = new SequentialPlannerConfig { RelevancyThreshold = 0.78 };
var context = new SKContext(variables, skills.ReadOnlySkillCollection, logger, cancellationToken);
var config = new SequentialPlannerConfig { RelevancyThreshold = 0.78, Memory = memory.Object };
var semanticQuery = "test";

// Act
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private SKContext CreateSKContext(
ContextVariables? variables = null,
CancellationToken cancellationToken = default)
{
return new SKContext(variables, kernel.Memory, kernel.Skills, kernel.Log, cancellationToken);
return new SKContext(variables, kernel.Skills, kernel.Log, cancellationToken);
}

private static Mock<ISKFunction> CreateMockFunction(FunctionView functionView, string result = "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI.TextCompletion;
using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Planning;
using Microsoft.SemanticKernel.SemanticFunctions;
Expand All @@ -27,8 +26,6 @@ public async Task ItCanCreatePlanAsync(string goal)
var kernel = new Mock<IKernel>();
kernel.Setup(x => x.Log).Returns(new Mock<ILogger>().Object);

var memory = new Mock<ISemanticTextMemory>();

var input = new List<(string name, string skillName, string description, bool isSemantic)>()
{
("SendEmail", "email", "Send an e-mail", false),
Expand Down Expand Up @@ -66,14 +63,12 @@ public async Task ItCanCreatePlanAsync(string goal)

var context = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
new Mock<ILogger>().Object
);

var returnContext = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
new Mock<ILogger>().Object
);
Expand Down Expand Up @@ -153,8 +148,6 @@ public async Task InvalidXMLThrowsAsync()
{
// Arrange
var kernel = new Mock<IKernel>();
// kernel.Setup(x => x.Log).Returns(new Mock<ILogger>().Object);
var memory = new Mock<ISemanticTextMemory>();
var skills = new Mock<ISkillCollection>();

var functionsView = new FunctionsView();
Expand All @@ -163,14 +156,12 @@ public async Task InvalidXMLThrowsAsync()
var planString = "<plan>notvalid<</plan>";
var returnContext = new SKContext(
new ContextVariables(planString),
memory.Object,
skills.Object,
new Mock<ILogger>().Object
);

var context = new SKContext(
new ContextVariables(),
memory.Object,
skills.Object,
new Mock<ILogger>().Object
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public static async Task<IOrderedEnumerable<FunctionView>> GetAvailableFunctions
.ToList();

List<FunctionView>? result = null;
if (string.IsNullOrEmpty(semanticQuery) || context.Memory is NullMemory || config.RelevancyThreshold is null)
if (string.IsNullOrEmpty(semanticQuery) || config.Memory is NullMemory || config.RelevancyThreshold is null)
{
// If no semantic query is provided, return all available functions.
// If a Memory provider has not been registered, return all available functions.
Expand All @@ -85,10 +85,10 @@ public static async Task<IOrderedEnumerable<FunctionView>> GetAvailableFunctions
result = new List<FunctionView>();

// Remember functions in memory so that they can be searched.
await RememberFunctionsAsync(context, availableFunctions).ConfigureAwait(false);
await RememberFunctionsAsync(context, config.Memory, availableFunctions).ConfigureAwait(false);

// Search for functions that match the semantic query.
var memories = context.Memory.SearchAsync(PlannerMemoryCollectionName, semanticQuery!, config.MaxRelevantFunctions, config.RelevancyThreshold.Value,
var memories = config.Memory.SearchAsync(PlannerMemoryCollectionName, semanticQuery!, config.MaxRelevantFunctions, config.RelevancyThreshold.Value,
false,
context.CancellationToken);

Expand Down Expand Up @@ -130,8 +130,12 @@ public static async Task<IEnumerable<FunctionView>> GetRelevantFunctionsAsync(SK
/// Saves all available functions to memory.
/// </summary>
/// <param name="context">The SKContext to save the functions to.</param>
/// <param name="memory">The memory provide to store the functions to..</param>
/// <param name="availableFunctions">The available functions to save.</param>
internal static async Task RememberFunctionsAsync(SKContext context, List<FunctionView> availableFunctions)
internal static async Task RememberFunctionsAsync(
SKContext context,
ISemanticTextMemory memory,
List<FunctionView> availableFunctions)
{
// Check if the functions have already been saved to memory.
if (context.Variables.ContainsKey(PlanSKFunctionsAreRemembered))
Expand All @@ -147,14 +151,14 @@ internal static async Task RememberFunctionsAsync(SKContext context, List<Functi
var textToEmbed = function.ToEmbeddingString();

// It'd be nice if there were a saveIfNotExists method on the memory interface
var memoryEntry = await context.Memory.GetAsync(collection: PlannerMemoryCollectionName, key: key, withEmbedding: false,
var memoryEntry = await memory.GetAsync(collection: PlannerMemoryCollectionName, key: key, withEmbedding: false,
cancellationToken: context.CancellationToken).ConfigureAwait(false);
if (memoryEntry == null)
{
// TODO It'd be nice if the minRelevanceScore could be a parameter for each item that was saved to memory
// As folks may want to tune their functions to be more or less relevant.
// Memory now supports these such strategies.
await context.Memory.SaveInformationAsync(collection: PlannerMemoryCollectionName, text: textToEmbed, id: key, description: description,
await memory.SaveInformationAsync(collection: PlannerMemoryCollectionName, text: textToEmbed, id: key, description: description,
additionalMetadata: string.Empty, cancellationToken: context.CancellationToken).ConfigureAwait(false);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.SkillDefinition;

namespace Microsoft.SemanticKernel.Planning.Sequential;
Expand Down Expand Up @@ -61,6 +62,11 @@ public sealed class SequentialPlannerConfig
/// </summary>
public bool AllowMissingFunctions { get; set; } = false;

/// <summary>
/// Semantic memory to use for function lookup (optional).
/// </summary>
public ISemanticTextMemory Memory { get; set; } = NullMemory.Instance;

/// <summary>
/// Optional callback to get the available functions for planning.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public async Task CreatePlanGoalRelevantAsync(string prompt, string expectedFunc
TestHelpers.ImportSampleSkills(kernel);

var planner = new Microsoft.SemanticKernel.Planning.SequentialPlanner(kernel,
new SequentialPlannerConfig { RelevancyThreshold = 0.65, MaxRelevantFunctions = 30 });
new SequentialPlannerConfig { RelevancyThreshold = 0.65, MaxRelevantFunctions = 30, Memory = kernel.Memory });

// Act
var plan = await planner.CreatePlanAsync(prompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
using Microsoft.SemanticKernel.Skills.Core;
using Microsoft.SemanticKernel.Skills.Web;
using Microsoft.SemanticKernel.Skills.Web.Bing;
using SemanticKernel.IntegrationTests.Fakes;
using SemanticKernel.IntegrationTests.TestSettings;
using Xunit;
using Xunit.Abstractions;
Expand Down Expand Up @@ -133,8 +132,6 @@ private IKernel InitializeKernel(bool useEmbeddings = false, bool useChatModel =

var kernel = builder.Build();

_ = kernel.ImportSkill(new EmailSkillFake());

return kernel;
}

Expand Down
16 changes: 6 additions & 10 deletions dotnet/src/SemanticKernel.Abstractions/Orchestration/SKContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,12 @@ public SKContext Fail(string errorDescription, Exception? exception = null)
/// <summary>
/// Semantic memory
/// </summary>
public ISemanticTextMemory Memory { get; }
[Obsolete("Memory no longer passed through SKContext. Instead, initialize your skill class with the memory provider it needs.")]
public ISemanticTextMemory Memory
{
get => throw new InvalidOperationException(
"Memory no longer passed through SKContext. Instead, initialize your skill class with the memory provider it needs.");
}

/// <summary>
/// Read only skills collection
Expand Down Expand Up @@ -132,19 +137,16 @@ public ISKFunction Func(string skillName, string functionName)
/// Constructor for the context.
/// </summary>
/// <param name="variables">Context variables to include in context.</param>
/// <param name="memory">Semantic text memory unit to include in context.</param>
/// <param name="skills">Skills to include in context.</param>
/// <param name="logger">Logger for operations in context.</param>
/// <param name="cancellationToken">Optional cancellation token for operations in context.</param>
public SKContext(
ContextVariables? variables = null,
ISemanticTextMemory? memory = null,
IReadOnlySkillCollection? skills = null,
ILogger? logger = null,
CancellationToken cancellationToken = default)
{
this.Variables = variables ?? new();
this.Memory = memory ?? NullMemory.Instance;
this.Skills = skills ?? NullReadOnlySkillCollection.Instance;
this.Log = logger ?? NullLogger.Instance;
this.CancellationToken = cancellationToken;
Expand All @@ -170,7 +172,6 @@ public SKContext Clone()
{
return new SKContext(
variables: this.Variables.Clone(),
memory: this.Memory,
skills: this.Skills,
logger: this.Log,
cancellationToken: this.CancellationToken)
Expand Down Expand Up @@ -200,11 +201,6 @@ private string DebuggerDisplay
display += $", Skills = {view.NativeFunctions.Count + view.SemanticFunctions.Count}";
}

if (this.Memory is ISemanticTextMemory memory && memory is not NullMemory)
{
display += $", Memory = {memory.GetType().Name}";
}

display += $", Culture = {this.Culture.EnglishName}";

return display;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
<ItemGroup>
<InternalsVisibleTo Include="Microsoft.SemanticKernel.Core" />
<InternalsVisibleTo Include="SemanticKernel.UnitTests" />
<InternalsVisibleTo Include="Extensions.UnitTests" />
<InternalsVisibleTo Include="DynamicProxyGenAssembly2" /> <!-- Moq -->
</ItemGroup>
</Project>
Loading

0 comments on commit 9dd8604

Please sign in to comment.