diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs
index a2c771f208f9..aab5f90ffab6 100644
--- a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs
+++ b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs
@@ -41,6 +41,16 @@ public interface IAgentRuntime
/// A task that represents the asynchronous operation.
ValueTask RuntimeSendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default);
+ ///
+ /// Sends a message directly to another agent.
+ ///
+ /// The message to be sent.
+ /// The recipient of the message.
+ /// The agent sending the message.
+ /// A token to cancel the operation.
+ /// A task that represents the response to th message.
+ ValueTask SendMessageAsync(IMessage message, AgentId recipient, AgentId? sender, CancellationToken? cancellationToken = default);
+
///
/// Publishes a message to a topic.
///
diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
index baf16f31981f..1a505be6c3b7 100644
--- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
+++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
@@ -4,6 +4,7 @@
using System.Collections.Concurrent;
using System.Reflection;
using System.Threading.Channels;
+using Google.Protobuf;
using Grpc.Core;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.DependencyInjection;
@@ -16,19 +17,20 @@ public sealed class GrpcAgentRuntime(
AgentRpc.AgentRpcClient client,
IHostApplicationLifetime hostApplicationLifetime,
IServiceProvider serviceProvider,
- [FromKeyedServices("AgentTypes")] IEnumerable> configuredAgentTypes,
+ [FromKeyedServices("AgentTypes")] IEnumerable> configuredAgentTypes,
ILogger logger
) : AgentRuntime(
hostApplicationLifetime,
serviceProvider,
- configuredAgentTypes
+ configuredAgentTypes,
+ logger
), IDisposable
{
private readonly object _channelLock = new();
- private readonly ConcurrentDictionary _agentTypes = new();
+ private readonly ConcurrentDictionary _agentTypes = new();
private readonly ConcurrentDictionary<(string Type, string Key), Agent> _agents = new();
private readonly ConcurrentDictionary _pendingRequests = new();
- private readonly ConcurrentDictionary> _agentsForEvent = new();
+ private readonly ConcurrentDictionary> _agentsForEvent = new();
private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel = Channel.CreateBounded<(Message, TaskCompletionSource)>(new BoundedChannelOptions(1024)
{
AllowSynchronousContinuations = true,
@@ -38,8 +40,8 @@ ILogger logger
});
private readonly AgentRpc.AgentRpcClient _client = client;
public readonly IServiceProvider ServiceProvider = serviceProvider;
- private readonly IEnumerable> _configuredAgentTypes = configuredAgentTypes;
- private readonly ILogger _logger = logger;
+ private readonly IEnumerable> _configuredAgentTypes = configuredAgentTypes;
+ private new readonly ILogger _logger = logger;
private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping);
private AsyncDuplexStreamingCall? _channel;
private Task? _readTask;
@@ -192,7 +194,7 @@ private async Task RunWritePump()
item.WriteCompletionSource.TrySetCanceled();
}
}
- private Agent GetOrActivateAgent(AgentId agentId)
+ private new Agent GetOrActivateAgent(AgentId agentId)
{
if (!_agents.TryGetValue((agentId.Type, agentId.Key), out var agent))
{
@@ -290,6 +292,18 @@ await WriteChannelAsync(new Message
}
}
}
+ public override async ValueTask SendMessageAsync(IMessage message, AgentId agentId, AgentId? agent = null, CancellationToken? cancellationToken = default)
+ {
+ var request = new RpcRequest
+ {
+ RequestId = Guid.NewGuid().ToString(),
+ Source = agent,
+ Target = agentId,
+ Payload = (Payload)message,
+ };
+ var response = await InvokeRequestAsync(request).ConfigureAwait(false);
+ return response;
+ }
// new is intentional
public new async ValueTask RuntimeSendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default)
{
diff --git a/dotnet/src/Microsoft.AutoGen/Core/Agent.cs b/dotnet/src/Microsoft.AutoGen/Core/Agent.cs
index ee0abaad280a..d9c2f3fea40b 100644
--- a/dotnet/src/Microsoft.AutoGen/Core/Agent.cs
+++ b/dotnet/src/Microsoft.AutoGen/Core/Agent.cs
@@ -422,7 +422,7 @@ private Task CallHandlerAsync(CloudEvent item, CancellationToken cancellationTok
return Task.CompletedTask;
}
- public Task HandleRequestAsync(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });
+ public virtual Task HandleRequestAsync(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" });
///
/// Handles a generic object
diff --git a/dotnet/src/Microsoft.AutoGen/Core/AgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core/AgentRuntime.cs
index 3d824daa8da6..7da485fa874f 100644
--- a/dotnet/src/Microsoft.AutoGen/Core/AgentRuntime.cs
+++ b/dotnet/src/Microsoft.AutoGen/Core/AgentRuntime.cs
@@ -1,9 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentRuntime.cs
using System.Collections.Concurrent;
+using System.Threading.Channels;
+using Google.Protobuf;
+using Google.Protobuf.WellKnownTypes;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
namespace Microsoft.AutoGen.Core;
@@ -19,24 +23,41 @@ namespace Microsoft.AutoGen.Core;
public class AgentRuntime(
IHostApplicationLifetime hostApplicationLifetime,
IServiceProvider serviceProvider,
- [FromKeyedServices("AgentTypes")] IEnumerable> configuredAgentTypes) :
+ [FromKeyedServices("AgentTypes")] IEnumerable> configuredAgentTypes,
+ ILogger logger) :
AgentRuntimeBase(
hostApplicationLifetime,
serviceProvider,
- configuredAgentTypes)
+ configuredAgentTypes,
+ logger)
{
private readonly ConcurrentDictionary _agentStates = new();
private readonly ConcurrentDictionary> _subscriptionsByAgentType = new();
private readonly ConcurrentDictionary> _subscriptionsByTopic = new();
private readonly ConcurrentDictionary> _subscriptionsByGuid = new();
private readonly IRegistry _registry = serviceProvider.GetRequiredService();
+ private readonly ConcurrentDictionary> _pendingRequests = new();
+ private static readonly TimeSpan s_agentResponseTimeout = TimeSpan.FromSeconds(300);
+ private new readonly ILogger _logger = logger;
///
public override async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken cancellationToken = default)
{
await _registry.RegisterAgentTypeAsync(request, this);
}
-
+ ///
+ public override async ValueTask SendMessageAsync(IMessage message, AgentId agentId, AgentId? agent = null, CancellationToken? cancellationToken = default)
+ {
+ var request = new RpcRequest
+ {
+ RequestId = Guid.NewGuid().ToString(),
+ Source = agent,
+ Target = agentId,
+ Payload = new Payload { Data = Any.Pack(message).ToByteString() }
+ };
+ var response = await InvokeRequestAsync(request).ConfigureAwait(false);
+ return response;
+ }
///
public override ValueTask SaveStateAsync(AgentState value, CancellationToken cancellationToken = default)
{
@@ -142,7 +163,7 @@ public override async ValueTask RemoveSubscriptionAs
return response;
}
- public ValueTask> GetSubscriptionsAsync(Type type)
+ public ValueTask> GetSubscriptionsAsync(System.Type type)
{
if (_subscriptionsByAgentType.TryGetValue(type.Name, out var subscriptions))
{
@@ -159,4 +180,57 @@ public override ValueTask> GetSubscriptionsAsync(GetSubscript
}
return new ValueTask>(subscriptions);
}
+ public override async ValueTask DispatchRequestAsync(RpcRequest request)
+ {
+ var requestId = request.RequestId;
+ if (request.Target is null)
+ {
+ throw new InvalidOperationException($"Request message is missing a target. Message: '{request}'.");
+ }
+ var agentId = request.Target;
+ await InvokeRequestDelegate(_mailbox, request, async request =>
+ {
+ return await InvokeRequestAsync(request).ConfigureAwait(true);
+ }).ConfigureAwait(false);
+ }
+ public override void DispatchResponse(RpcResponse response)
+ {
+ if (!_pendingRequests.TryRemove(response.RequestId, out var completion))
+ {
+ _logger.LogWarning("Received response for unknown request id: {RequestId}.", response.RequestId);
+ return;
+ }
+ // Complete the request.
+ completion.SetResult(response);
+ }
+ public async ValueTask InvokeRequestAsync(RpcRequest request, CancellationToken cancellationToken = default)
+ {
+ var agentId = request.Target;
+ // get the agent
+ var agent = GetOrActivateAgent(agentId);
+
+ // Proxy the request to the agent.
+ var originalRequestId = request.RequestId;
+ var completion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ _pendingRequests.TryAdd(request.RequestId, completion);
+ //request.RequestId = Guid.NewGuid().ToString();
+ agent.ReceiveMessage(new Message() { Request = request });
+ // Wait for the response and send it back to the caller.
+ var response = await completion.Task.WaitAsync(s_agentResponseTimeout);
+ response.RequestId = originalRequestId;
+ return response;
+ }
+ private static async Task InvokeRequestDelegate(Channel