diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs index a2c771f208f9..aab5f90ffab6 100644 --- a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs @@ -41,6 +41,16 @@ public interface IAgentRuntime /// A task that represents the asynchronous operation. ValueTask RuntimeSendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default); + /// + /// Sends a message directly to another agent. + /// + /// The message to be sent. + /// The recipient of the message. + /// The agent sending the message. + /// A token to cancel the operation. + /// A task that represents the response to th message. + ValueTask SendMessageAsync(IMessage message, AgentId recipient, AgentId? sender, CancellationToken? cancellationToken = default); + /// /// Publishes a message to a topic. /// diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs index baf16f31981f..1a505be6c3b7 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -4,6 +4,7 @@ using System.Collections.Concurrent; using System.Reflection; using System.Threading.Channels; +using Google.Protobuf; using Grpc.Core; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.DependencyInjection; @@ -16,19 +17,20 @@ public sealed class GrpcAgentRuntime( AgentRpc.AgentRpcClient client, IHostApplicationLifetime hostApplicationLifetime, IServiceProvider serviceProvider, - [FromKeyedServices("AgentTypes")] IEnumerable> configuredAgentTypes, + [FromKeyedServices("AgentTypes")] IEnumerable> configuredAgentTypes, ILogger logger ) : AgentRuntime( hostApplicationLifetime, serviceProvider, - configuredAgentTypes + configuredAgentTypes, + logger ), IDisposable { private readonly object _channelLock = new(); - private readonly ConcurrentDictionary _agentTypes = new(); + private readonly ConcurrentDictionary _agentTypes = new(); private readonly ConcurrentDictionary<(string Type, string Key), Agent> _agents = new(); private readonly ConcurrentDictionary _pendingRequests = new(); - private readonly ConcurrentDictionary> _agentsForEvent = new(); + private readonly ConcurrentDictionary> _agentsForEvent = new(); private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel = Channel.CreateBounded<(Message, TaskCompletionSource)>(new BoundedChannelOptions(1024) { AllowSynchronousContinuations = true, @@ -38,8 +40,8 @@ ILogger logger }); private readonly AgentRpc.AgentRpcClient _client = client; public readonly IServiceProvider ServiceProvider = serviceProvider; - private readonly IEnumerable> _configuredAgentTypes = configuredAgentTypes; - private readonly ILogger _logger = logger; + private readonly IEnumerable> _configuredAgentTypes = configuredAgentTypes; + private new readonly ILogger _logger = logger; private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping); private AsyncDuplexStreamingCall? _channel; private Task? _readTask; @@ -192,7 +194,7 @@ private async Task RunWritePump() item.WriteCompletionSource.TrySetCanceled(); } } - private Agent GetOrActivateAgent(AgentId agentId) + private new Agent GetOrActivateAgent(AgentId agentId) { if (!_agents.TryGetValue((agentId.Type, agentId.Key), out var agent)) { @@ -290,6 +292,18 @@ await WriteChannelAsync(new Message } } } + public override async ValueTask SendMessageAsync(IMessage message, AgentId agentId, AgentId? agent = null, CancellationToken? cancellationToken = default) + { + var request = new RpcRequest + { + RequestId = Guid.NewGuid().ToString(), + Source = agent, + Target = agentId, + Payload = (Payload)message, + }; + var response = await InvokeRequestAsync(request).ConfigureAwait(false); + return response; + } // new is intentional public new async ValueTask RuntimeSendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/Microsoft.AutoGen/Core/Agent.cs b/dotnet/src/Microsoft.AutoGen/Core/Agent.cs index ee0abaad280a..d9c2f3fea40b 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/Agent.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/Agent.cs @@ -422,7 +422,7 @@ private Task CallHandlerAsync(CloudEvent item, CancellationToken cancellationTok return Task.CompletedTask; } - public Task HandleRequestAsync(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" }); + public virtual Task HandleRequestAsync(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" }); /// /// Handles a generic object diff --git a/dotnet/src/Microsoft.AutoGen/Core/AgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core/AgentRuntime.cs index 3d824daa8da6..7da485fa874f 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/AgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/AgentRuntime.cs @@ -1,9 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AgentRuntime.cs using System.Collections.Concurrent; +using System.Threading.Channels; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; namespace Microsoft.AutoGen.Core; @@ -19,24 +23,41 @@ namespace Microsoft.AutoGen.Core; public class AgentRuntime( IHostApplicationLifetime hostApplicationLifetime, IServiceProvider serviceProvider, - [FromKeyedServices("AgentTypes")] IEnumerable> configuredAgentTypes) : + [FromKeyedServices("AgentTypes")] IEnumerable> configuredAgentTypes, + ILogger logger) : AgentRuntimeBase( hostApplicationLifetime, serviceProvider, - configuredAgentTypes) + configuredAgentTypes, + logger) { private readonly ConcurrentDictionary _agentStates = new(); private readonly ConcurrentDictionary> _subscriptionsByAgentType = new(); private readonly ConcurrentDictionary> _subscriptionsByTopic = new(); private readonly ConcurrentDictionary> _subscriptionsByGuid = new(); private readonly IRegistry _registry = serviceProvider.GetRequiredService(); + private readonly ConcurrentDictionary> _pendingRequests = new(); + private static readonly TimeSpan s_agentResponseTimeout = TimeSpan.FromSeconds(300); + private new readonly ILogger _logger = logger; /// public override async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken cancellationToken = default) { await _registry.RegisterAgentTypeAsync(request, this); } - + /// + public override async ValueTask SendMessageAsync(IMessage message, AgentId agentId, AgentId? agent = null, CancellationToken? cancellationToken = default) + { + var request = new RpcRequest + { + RequestId = Guid.NewGuid().ToString(), + Source = agent, + Target = agentId, + Payload = new Payload { Data = Any.Pack(message).ToByteString() } + }; + var response = await InvokeRequestAsync(request).ConfigureAwait(false); + return response; + } /// public override ValueTask SaveStateAsync(AgentState value, CancellationToken cancellationToken = default) { @@ -142,7 +163,7 @@ public override async ValueTask RemoveSubscriptionAs return response; } - public ValueTask> GetSubscriptionsAsync(Type type) + public ValueTask> GetSubscriptionsAsync(System.Type type) { if (_subscriptionsByAgentType.TryGetValue(type.Name, out var subscriptions)) { @@ -159,4 +180,57 @@ public override ValueTask> GetSubscriptionsAsync(GetSubscript } return new ValueTask>(subscriptions); } + public override async ValueTask DispatchRequestAsync(RpcRequest request) + { + var requestId = request.RequestId; + if (request.Target is null) + { + throw new InvalidOperationException($"Request message is missing a target. Message: '{request}'."); + } + var agentId = request.Target; + await InvokeRequestDelegate(_mailbox, request, async request => + { + return await InvokeRequestAsync(request).ConfigureAwait(true); + }).ConfigureAwait(false); + } + public override void DispatchResponse(RpcResponse response) + { + if (!_pendingRequests.TryRemove(response.RequestId, out var completion)) + { + _logger.LogWarning("Received response for unknown request id: {RequestId}.", response.RequestId); + return; + } + // Complete the request. + completion.SetResult(response); + } + public async ValueTask InvokeRequestAsync(RpcRequest request, CancellationToken cancellationToken = default) + { + var agentId = request.Target; + // get the agent + var agent = GetOrActivateAgent(agentId); + + // Proxy the request to the agent. + var originalRequestId = request.RequestId; + var completion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _pendingRequests.TryAdd(request.RequestId, completion); + //request.RequestId = Guid.NewGuid().ToString(); + agent.ReceiveMessage(new Message() { Request = request }); + // Wait for the response and send it back to the caller. + var response = await completion.Task.WaitAsync(s_agentResponseTimeout); + response.RequestId = originalRequestId; + return response; + } + private static async Task InvokeRequestDelegate(Channel mailbox, RpcRequest request, Func> func) + { + try + { + var response = await func(request); + response.RequestId = request.RequestId; + await mailbox.Writer.WriteAsync(new Message { Response = response }).ConfigureAwait(false); + } + catch (Exception ex) + { + await mailbox.Writer.WriteAsync(new Message { Response = new RpcResponse { RequestId = request.RequestId, Error = ex.Message } }).ConfigureAwait(false); + } + } } diff --git a/dotnet/src/Microsoft.AutoGen/Core/AgentRuntimeBase.cs b/dotnet/src/Microsoft.AutoGen/Core/AgentRuntimeBase.cs index 6584e2b059b2..81ce479336cc 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/AgentRuntimeBase.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/AgentRuntimeBase.cs @@ -8,6 +8,7 @@ using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; namespace Microsoft.AutoGen.Core; /// @@ -19,7 +20,8 @@ namespace Microsoft.AutoGen.Core; public abstract class AgentRuntimeBase( IHostApplicationLifetime hostApplicationLifetime, IServiceProvider serviceProvider, - [FromKeyedServices("AgentTypes")] IEnumerable> configuredAgentTypes) : IHostedService, IAgentRuntime + [FromKeyedServices("AgentTypes")] IEnumerable> configuredAgentTypes, + ILogger logger) : IHostedService, IAgentRuntime { public IServiceProvider RuntimeServiceProvider { get; } = serviceProvider; protected readonly ConcurrentDictionary _pendingClientRequests = new(); @@ -31,6 +33,7 @@ public abstract class AgentRuntimeBase( private readonly IEnumerable> _configuredAgentTypes = configuredAgentTypes; private Task? _mailboxTask; private readonly object _channelLock = new(); + protected readonly ILogger _logger = logger; /// /// Starts the agent runtime. @@ -40,6 +43,7 @@ public abstract class AgentRuntimeBase( /// Task public async Task StartAsync(CancellationToken cancellationToken) { + _logger.LogInformation("Starting AgentRuntimeBase..."); StartCore(); foreach (var (typeName, type) in _configuredAgentTypes) @@ -102,6 +106,12 @@ public async Task RunMessagePump() if (message == null) { continue; } switch (message) { + case Message msg when msg.Request != null: + await DispatchRequestAsync(msg.Request).ConfigureAwait(true); + break; + case Message msg when msg.Response != null: + DispatchResponse(msg.Response); + break; case Message msg when msg.CloudEvent != null: var item = msg.CloudEvent; @@ -132,6 +142,8 @@ public async Task RunMessagePump() } } } + public abstract void DispatchResponse(RpcResponse response); + public abstract ValueTask DispatchRequestAsync(RpcRequest request); public abstract ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken cancellationToken = default); public async ValueTask PublishMessageAsync(IMessage message, TopicId topic, IAgent? sender, CancellationToken? cancellationToken = default) { @@ -145,7 +157,7 @@ public async ValueTask PublishMessageAsync(IMessage message, TopicId topic, IAge public abstract ValueTask RemoveSubscriptionAsync(RemoveSubscriptionRequest request, CancellationToken cancellationToken = default); public abstract ValueTask> GetSubscriptionsAsync(GetSubscriptionsRequest request, CancellationToken cancellationToken = default); - private Agent GetOrActivateAgent(AgentId agentId) + protected Agent GetOrActivateAgent(AgentId agentId) { if (!_agents.TryGetValue((agentId.Type, agentId.Key), out var agent)) { @@ -206,4 +218,5 @@ private async ValueTask DispatchEventsToAgentsAsync(CloudEvent cloudEvent, Cance public abstract ValueTask RuntimeSendRequestAsync(IAgent agent, RpcRequest request, CancellationToken cancellationToken = default); public abstract ValueTask RuntimeSendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default); + public abstract ValueTask SendMessageAsync(IMessage message, AgentId recipient, AgentId? sender, CancellationToken? cancellationToken = default); } diff --git a/dotnet/src/Microsoft.AutoGen/Core/UninitializedAgentWorker.cs b/dotnet/src/Microsoft.AutoGen/Core/UninitializedAgentWorker.cs index 28d2c5f44ebc..5eb3c418bb9e 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/UninitializedAgentWorker.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/UninitializedAgentWorker.cs @@ -18,6 +18,7 @@ public class UninitializedAgentWorker() : IAgentRuntime public ValueTask RemoveSubscriptionAsync(RemoveSubscriptionRequest request, CancellationToken cancellationToken = default) => throw new AgentInitalizedIncorrectlyException(AgentNotInitializedMessage); public ValueTask PublishMessageAsync(IMessage message, TopicId topic, IAgent? sender, CancellationToken? cancellationToken = null) => throw new AgentInitalizedIncorrectlyException(AgentNotInitializedMessage); public ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken cancellationToken = default) => throw new AgentInitalizedIncorrectlyException(AgentNotInitializedMessage); + public ValueTask SendMessageAsync(IMessage message, AgentId recipient, AgentId? sender, CancellationToken? cancellationToken = null) => throw new AgentInitalizedIncorrectlyException(AgentNotInitializedMessage); public class AgentInitalizedIncorrectlyException(string message) : Exception(message) { } diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs new file mode 100644 index 000000000000..df3de5680599 --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AgentRuntimeTests.cs + +using Google.Protobuf; +using Google.Protobuf.Reflection; +using Google.Protobuf.WellKnownTypes; +using Microsoft.AutoGen.Contracts; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Moq; +using Xunit; + +namespace Microsoft.AutoGen.Core.Tests; + +public class AgentRuntimeTests +{ + private readonly Mock _hostApplicationLifetimeMock; + private readonly Mock _serviceProviderMock; + private readonly Mock> _loggerMock; + private readonly Mock _registryMock; + private readonly AgentRuntime _agentRuntime; + + public AgentRuntimeTests() + { + _hostApplicationLifetimeMock = new Mock(); + _serviceProviderMock = new Mock(); + _loggerMock = new Mock>(); + _registryMock = new Mock(); + + _serviceProviderMock.Setup(sp => sp.GetService(typeof(IRegistry))).Returns(_registryMock.Object); + + var configuredAgentTypes = new List> + { + new Tuple("TestAgent", typeof(TestAgent)) + }; + + _agentRuntime = new AgentRuntime( + _hostApplicationLifetimeMock.Object, + _serviceProviderMock.Object, + configuredAgentTypes, + _loggerMock.Object); + } + + [Fact] + public async Task SendMessageAsync_ShouldReturnResponse() + { + // Arrange + var fixture = new InMemoryAgentRuntimeFixture(); + var (runtime, agent) = fixture.Start(); + var agentId = new AgentId { Type = "TestAgent", Key = "test-key" }; + var message = new TextMessage { TextMessage_ = "Hello, World!" }; + var agentMock = new Mock(MockBehavior.Loose, new AgentsMetadata(TypeRegistry.Empty, new Dictionary(), new Dictionary>(), new Dictionary>()), new Logger(new LoggerFactory())); + agentMock.CallBase = true; // Enable calling the base class methods + agentMock.Setup(a => a.HandleObjectAsync(It.IsAny(), It.IsAny())).Callback((msg, ct) => + { + var response = new RpcResponse + { + RequestId = "test-request-id", + Payload = new Payload { Data = Any.Pack(new TextMessage { TextMessage_ = "Response" }).ToByteString() } + }; + _agentRuntime.DispatchResponse(response); + }); + + // Act + var response = await runtime.SendMessageAsync(message, agentId, agent.AgentId); + + // Assert + Assert.NotNull(response); + var any = Any.Parser.ParseFrom(response.Payload.Data); + var unpackedMessage = any.Unpack(); + Assert.Equal("Response", unpackedMessage.TextMessage_); + fixture.Stop(); + } +} diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs index 28b1334ba618..ce29188297a8 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs @@ -2,6 +2,8 @@ // TestAgent.cs using System.Collections.Concurrent; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; @@ -28,6 +30,15 @@ public Task Handle(int item) ReceivedItems.Add(item); return Task.CompletedTask; } + public override Task HandleRequestAsync(RpcRequest request) + { + var response = new RpcResponse + { + RequestId = request.RequestId, + Payload = new Payload { Data = Any.Pack(new TextMessage { TextMessage_ = "Response" }).ToByteString() } + }; + return Task.FromResult(response); + } public List ReceivedItems { get; private set; } = []; ///