Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rysweet 5217 add send message #5219

Merged
merged 3 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ public interface IAgentRuntime
/// <returns>A task that represents the asynchronous operation.</returns>
ValueTask RuntimeSendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default);

/// <summary>
/// Sends a message directly to another agent.
/// </summary>
/// <param name="message">The message to be sent.</param>
/// <param name="recipient">The recipient of the message.</param>
/// <param name="sender">The agent sending the message.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task that represents the response to th message.</returns>
ValueTask<RpcResponse> SendMessageAsync(IMessage message, AgentId recipient, AgentId? sender, CancellationToken? cancellationToken = default);

/// <summary>
/// Publishes a message to a topic.
/// </summary>
Expand Down
28 changes: 21 additions & 7 deletions dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Concurrent;
using System.Reflection;
using System.Threading.Channels;
using Google.Protobuf;
using Grpc.Core;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.DependencyInjection;
Expand All @@ -16,19 +17,20 @@ public sealed class GrpcAgentRuntime(
AgentRpc.AgentRpcClient client,
IHostApplicationLifetime hostApplicationLifetime,
IServiceProvider serviceProvider,
[FromKeyedServices("AgentTypes")] IEnumerable<Tuple<string, Type>> configuredAgentTypes,
[FromKeyedServices("AgentTypes")] IEnumerable<Tuple<string, System.Type>> configuredAgentTypes,
ILogger<GrpcAgentRuntime> logger
) : AgentRuntime(
hostApplicationLifetime,
serviceProvider,
configuredAgentTypes
configuredAgentTypes,
logger
), IDisposable
{
private readonly object _channelLock = new();
private readonly ConcurrentDictionary<string, Type> _agentTypes = new();
private readonly ConcurrentDictionary<string, global::System.Type> _agentTypes = new();
private readonly ConcurrentDictionary<(string Type, string Key), Agent> _agents = new();
private readonly ConcurrentDictionary<string, (Agent Agent, string OriginalRequestId)> _pendingRequests = new();
private readonly ConcurrentDictionary<string, HashSet<Type>> _agentsForEvent = new();
private readonly ConcurrentDictionary<string, HashSet<global::System.Type>> _agentsForEvent = new();
private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel = Channel.CreateBounded<(Message, TaskCompletionSource)>(new BoundedChannelOptions(1024)
{
AllowSynchronousContinuations = true,
Expand All @@ -38,8 +40,8 @@ ILogger<GrpcAgentRuntime> logger
});
private readonly AgentRpc.AgentRpcClient _client = client;
public readonly IServiceProvider ServiceProvider = serviceProvider;
private readonly IEnumerable<Tuple<string, Type>> _configuredAgentTypes = configuredAgentTypes;
private readonly ILogger<GrpcAgentRuntime> _logger = logger;
private readonly IEnumerable<Tuple<string, System.Type>> _configuredAgentTypes = configuredAgentTypes;
private new readonly ILogger<GrpcAgentRuntime> _logger = logger;
private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping);
private AsyncDuplexStreamingCall<Message, Message>? _channel;
private Task? _readTask;
Expand Down Expand Up @@ -192,7 +194,7 @@ private async Task RunWritePump()
item.WriteCompletionSource.TrySetCanceled();
}
}
private Agent GetOrActivateAgent(AgentId agentId)
private new Agent GetOrActivateAgent(AgentId agentId)
{
if (!_agents.TryGetValue((agentId.Type, agentId.Key), out var agent))
{
Expand Down Expand Up @@ -290,6 +292,18 @@ await WriteChannelAsync(new Message
}
}
}
public override async ValueTask<RpcResponse> SendMessageAsync(IMessage message, AgentId agentId, AgentId? agent = null, CancellationToken? cancellationToken = default)
{
var request = new RpcRequest
{
RequestId = Guid.NewGuid().ToString(),
Source = agent,
Target = agentId,
Payload = (Payload)message,
};
var response = await InvokeRequestAsync(request).ConfigureAwait(false);
return response;
}
// new is intentional
public new async ValueTask RuntimeSendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default)
{
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Microsoft.AutoGen/Core/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ private Task CallHandlerAsync(CloudEvent item, CancellationToken cancellationTok
return Task.CompletedTask;
}

public Task<RpcResponse> HandleRequestAsync(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });
public virtual Task<RpcResponse> HandleRequestAsync(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });

/// <summary>
/// Handles a generic object
Expand Down
82 changes: 78 additions & 4 deletions dotnet/src/Microsoft.AutoGen/Core/AgentRuntime.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentRuntime.cs
using System.Collections.Concurrent;
using System.Threading.Channels;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;

namespace Microsoft.AutoGen.Core;

Expand All @@ -19,24 +23,41 @@ namespace Microsoft.AutoGen.Core;
public class AgentRuntime(
IHostApplicationLifetime hostApplicationLifetime,
IServiceProvider serviceProvider,
[FromKeyedServices("AgentTypes")] IEnumerable<Tuple<string, Type>> configuredAgentTypes) :
[FromKeyedServices("AgentTypes")] IEnumerable<Tuple<string, System.Type>> configuredAgentTypes,
ILogger<AgentRuntime> logger) :
AgentRuntimeBase(
hostApplicationLifetime,
serviceProvider,
configuredAgentTypes)
configuredAgentTypes,
logger)
{
private readonly ConcurrentDictionary<string, AgentState> _agentStates = new();
private readonly ConcurrentDictionary<string, List<Subscription>> _subscriptionsByAgentType = new();
private readonly ConcurrentDictionary<string, List<string>> _subscriptionsByTopic = new();
private readonly ConcurrentDictionary<Guid, IDictionary<string, string>> _subscriptionsByGuid = new();
private readonly IRegistry _registry = serviceProvider.GetRequiredService<IRegistry>();
private readonly ConcurrentDictionary<string, TaskCompletionSource<RpcResponse>> _pendingRequests = new();
private static readonly TimeSpan s_agentResponseTimeout = TimeSpan.FromSeconds(300);
private new readonly ILogger<AgentRuntime> _logger = logger;

/// <inheritdoc />
public override async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken cancellationToken = default)
{
await _registry.RegisterAgentTypeAsync(request, this);
}

/// <inheritdoc />
public override async ValueTask<RpcResponse> SendMessageAsync(IMessage message, AgentId agentId, AgentId? agent = null, CancellationToken? cancellationToken = default)
{
var request = new RpcRequest
{
RequestId = Guid.NewGuid().ToString(),
Source = agent,
Target = agentId,
Payload = new Payload { Data = Any.Pack(message).ToByteString() }
};
var response = await InvokeRequestAsync(request).ConfigureAwait(false);
return response;
}
/// <inheritdoc />
public override ValueTask SaveStateAsync(AgentState value, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -142,7 +163,7 @@ public override async ValueTask<RemoveSubscriptionResponse> RemoveSubscriptionAs
return response;
}

public ValueTask<List<Subscription>> GetSubscriptionsAsync(Type type)
public ValueTask<List<Subscription>> GetSubscriptionsAsync(System.Type type)
{
if (_subscriptionsByAgentType.TryGetValue(type.Name, out var subscriptions))
{
Expand All @@ -159,4 +180,57 @@ public override ValueTask<List<Subscription>> GetSubscriptionsAsync(GetSubscript
}
return new ValueTask<List<Subscription>>(subscriptions);
}
public override async ValueTask DispatchRequestAsync(RpcRequest request)
{
var requestId = request.RequestId;
if (request.Target is null)
{
throw new InvalidOperationException($"Request message is missing a target. Message: '{request}'.");
}
var agentId = request.Target;
await InvokeRequestDelegate(_mailbox, request, async request =>
{
return await InvokeRequestAsync(request).ConfigureAwait(true);
}).ConfigureAwait(false);
}
public override void DispatchResponse(RpcResponse response)
{
if (!_pendingRequests.TryRemove(response.RequestId, out var completion))
{
_logger.LogWarning("Received response for unknown request id: {RequestId}.", response.RequestId);
return;
}
// Complete the request.
completion.SetResult(response);
}
public async ValueTask<RpcResponse> InvokeRequestAsync(RpcRequest request, CancellationToken cancellationToken = default)
{
var agentId = request.Target;
// get the agent
var agent = GetOrActivateAgent(agentId);

// Proxy the request to the agent.
var originalRequestId = request.RequestId;
var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
_pendingRequests.TryAdd(request.RequestId, completion);
//request.RequestId = Guid.NewGuid().ToString();
agent.ReceiveMessage(new Message() { Request = request });
// Wait for the response and send it back to the caller.
var response = await completion.Task.WaitAsync(s_agentResponseTimeout);
response.RequestId = originalRequestId;
return response;
}
private static async Task InvokeRequestDelegate(Channel<object> mailbox, RpcRequest request, Func<RpcRequest, Task<RpcResponse>> func)
{
try
{
var response = await func(request);
response.RequestId = request.RequestId;
await mailbox.Writer.WriteAsync(new Message { Response = response }).ConfigureAwait(false);
}
catch (Exception ex)
{
await mailbox.Writer.WriteAsync(new Message { Response = new RpcResponse { RequestId = request.RequestId, Error = ex.Message } }).ConfigureAwait(false);
}
}
}
17 changes: 15 additions & 2 deletions dotnet/src/Microsoft.AutoGen/Core/AgentRuntimeBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;

namespace Microsoft.AutoGen.Core;
/// <summary>
Expand All @@ -19,7 +20,8 @@ namespace Microsoft.AutoGen.Core;
public abstract class AgentRuntimeBase(
IHostApplicationLifetime hostApplicationLifetime,
IServiceProvider serviceProvider,
[FromKeyedServices("AgentTypes")] IEnumerable<Tuple<string, Type>> configuredAgentTypes) : IHostedService, IAgentRuntime
[FromKeyedServices("AgentTypes")] IEnumerable<Tuple<string, Type>> configuredAgentTypes,
ILogger<AgentRuntimeBase> logger) : IHostedService, IAgentRuntime
{
public IServiceProvider RuntimeServiceProvider { get; } = serviceProvider;
protected readonly ConcurrentDictionary<string, (Agent Agent, string OriginalRequestId)> _pendingClientRequests = new();
Expand All @@ -31,6 +33,7 @@ public abstract class AgentRuntimeBase(
private readonly IEnumerable<Tuple<string, Type>> _configuredAgentTypes = configuredAgentTypes;
private Task? _mailboxTask;
private readonly object _channelLock = new();
protected readonly ILogger<AgentRuntimeBase> _logger = logger;

/// <summary>
/// Starts the agent runtime.
Expand All @@ -40,6 +43,7 @@ public abstract class AgentRuntimeBase(
/// <returns>Task</returns>
public async Task StartAsync(CancellationToken cancellationToken)
{
_logger.LogInformation("Starting AgentRuntimeBase...");
StartCore();

foreach (var (typeName, type) in _configuredAgentTypes)
Expand Down Expand Up @@ -102,6 +106,12 @@ public async Task RunMessagePump()
if (message == null) { continue; }
switch (message)
{
case Message msg when msg.Request != null:
await DispatchRequestAsync(msg.Request).ConfigureAwait(true);
break;
case Message msg when msg.Response != null:
DispatchResponse(msg.Response);
break;
case Message msg when msg.CloudEvent != null:

var item = msg.CloudEvent;
Expand Down Expand Up @@ -132,6 +142,8 @@ public async Task RunMessagePump()
}
}
}
public abstract void DispatchResponse(RpcResponse response);
public abstract ValueTask DispatchRequestAsync(RpcRequest request);
public abstract ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken cancellationToken = default);
public async ValueTask PublishMessageAsync(IMessage message, TopicId topic, IAgent? sender, CancellationToken? cancellationToken = default)
{
Expand All @@ -145,7 +157,7 @@ public async ValueTask PublishMessageAsync(IMessage message, TopicId topic, IAge
public abstract ValueTask<RemoveSubscriptionResponse> RemoveSubscriptionAsync(RemoveSubscriptionRequest request, CancellationToken cancellationToken = default);
public abstract ValueTask<List<Subscription>> GetSubscriptionsAsync(GetSubscriptionsRequest request, CancellationToken cancellationToken = default);

private Agent GetOrActivateAgent(AgentId agentId)
protected Agent GetOrActivateAgent(AgentId agentId)
{
if (!_agents.TryGetValue((agentId.Type, agentId.Key), out var agent))
{
Expand Down Expand Up @@ -206,4 +218,5 @@ private async ValueTask DispatchEventsToAgentsAsync(CloudEvent cloudEvent, Cance
public abstract ValueTask RuntimeSendRequestAsync(IAgent agent, RpcRequest request, CancellationToken cancellationToken = default);

public abstract ValueTask RuntimeSendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default);
public abstract ValueTask<RpcResponse> SendMessageAsync(IMessage message, AgentId recipient, AgentId? sender, CancellationToken? cancellationToken = default);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public class UninitializedAgentWorker() : IAgentRuntime
public ValueTask<RemoveSubscriptionResponse> RemoveSubscriptionAsync(RemoveSubscriptionRequest request, CancellationToken cancellationToken = default) => throw new AgentInitalizedIncorrectlyException(AgentNotInitializedMessage);
public ValueTask PublishMessageAsync(IMessage message, TopicId topic, IAgent? sender, CancellationToken? cancellationToken = null) => throw new AgentInitalizedIncorrectlyException(AgentNotInitializedMessage);
public ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken cancellationToken = default) => throw new AgentInitalizedIncorrectlyException(AgentNotInitializedMessage);
public ValueTask<RpcResponse> SendMessageAsync(IMessage message, AgentId recipient, AgentId? sender, CancellationToken? cancellationToken = null) => throw new AgentInitalizedIncorrectlyException(AgentNotInitializedMessage);
public class AgentInitalizedIncorrectlyException(string message) : Exception(message)
{
}
Expand Down
74 changes: 74 additions & 0 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentRuntimeTests.cs

using Google.Protobuf;
using Google.Protobuf.Reflection;
using Google.Protobuf.WellKnownTypes;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Moq;
using Xunit;

namespace Microsoft.AutoGen.Core.Tests;

public class AgentRuntimeTests
{
private readonly Mock<IHostApplicationLifetime> _hostApplicationLifetimeMock;
private readonly Mock<IServiceProvider> _serviceProviderMock;
private readonly Mock<ILogger<AgentRuntime>> _loggerMock;
private readonly Mock<IRegistry> _registryMock;
private readonly AgentRuntime _agentRuntime;

public AgentRuntimeTests()
{
_hostApplicationLifetimeMock = new Mock<IHostApplicationLifetime>();
_serviceProviderMock = new Mock<IServiceProvider>();
_loggerMock = new Mock<ILogger<AgentRuntime>>();
_registryMock = new Mock<IRegistry>();

_serviceProviderMock.Setup(sp => sp.GetService(typeof(IRegistry))).Returns(_registryMock.Object);

var configuredAgentTypes = new List<Tuple<string, System.Type>>
{
new Tuple<string, System.Type>("TestAgent", typeof(TestAgent))
};

_agentRuntime = new AgentRuntime(
_hostApplicationLifetimeMock.Object,
_serviceProviderMock.Object,
configuredAgentTypes,
_loggerMock.Object);
}

[Fact]
public async Task SendMessageAsync_ShouldReturnResponse()
{
// Arrange
var fixture = new InMemoryAgentRuntimeFixture();
var (runtime, agent) = fixture.Start();
var agentId = new AgentId { Type = "TestAgent", Key = "test-key" };
var message = new TextMessage { TextMessage_ = "Hello, World!" };
var agentMock = new Mock<TestAgent>(MockBehavior.Loose, new AgentsMetadata(TypeRegistry.Empty, new Dictionary<string, System.Type>(), new Dictionary<System.Type, HashSet<string>>(), new Dictionary<System.Type, HashSet<string>>()), new Logger<Agent>(new LoggerFactory()));
agentMock.CallBase = true; // Enable calling the base class methods
agentMock.Setup(a => a.HandleObjectAsync(It.IsAny<object>(), It.IsAny<CancellationToken>())).Callback<object, CancellationToken>((msg, ct) =>
{
var response = new RpcResponse
{
RequestId = "test-request-id",
Payload = new Payload { Data = Any.Pack(new TextMessage { TextMessage_ = "Response" }).ToByteString() }
};
_agentRuntime.DispatchResponse(response);
});

// Act
var response = await runtime.SendMessageAsync(message, agentId, agent.AgentId);

// Assert
Assert.NotNull(response);
var any = Any.Parser.ParseFrom(response.Payload.Data);
var unpackedMessage = any.Unpack<TextMessage>();
Assert.Equal("Response", unpackedMessage.TextMessage_);
fixture.Stop();
}
}
Loading
Loading