diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index a09f75a5be..f31a0b2ae4 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -282,16 +282,21 @@ private MemberDeclarationSyntax GenerateSetTargetMethod( var holderParameter = holder.Identifier; var containingInterface = methodDescription.ContainingInterface; + var targetType = containingInterface.ToTypeSyntax(); var isExtension = methodDescription.Key.ProxyBase.IsExtension; - var getTarget = InvocationExpression( + var (name, args) = isExtension switch + { + true => ("GetComponent", SingletonSeparatedList(Argument(TypeOfExpression(targetType)))), + _ => ("GetTarget", SeparatedList()) + }; + var getTarget = CastExpression( + targetType, + InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, holder, - GenericName(isExtension ? "GetComponent" : "GetTarget") - .WithTypeArgumentList( - TypeArgumentList( - SingletonSeparatedList(containingInterface.ToTypeSyntax()))))) - .WithArgumentList(ArgumentList()); + IdentifierName(name)), + ArgumentList(args))); var body = AssignmentExpression( @@ -305,7 +310,7 @@ private MemberDeclarationSyntax GenerateSetTargetMethod( .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); } - private MemberDeclarationSyntax GenerateGetTargetMethod(TargetFieldDescription targetField) + private static MethodDeclarationSyntax GenerateGetTargetMethod(TargetFieldDescription targetField) { return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetTarget") .WithParameterList(ParameterList()) diff --git a/src/Orleans.Core/Runtime/ClientGrainContext.cs b/src/Orleans.Core/Runtime/ClientGrainContext.cs index e228c7d212..ec42e4f61b 100644 --- a/src/Orleans.Core/Runtime/ClientGrainContext.cs +++ b/src/Orleans.Core/Runtime/ClientGrainContext.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Reflection; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -45,34 +46,31 @@ public ClientGrainContext(OutsideRuntimeClient runtimeClient) public bool Equals(IGrainContext other) => ReferenceEquals(this, other); - public TComponent GetComponent() where TComponent : class + public TComponent GetComponent() where TComponent : class => (TComponent)GetComponent(typeof(TComponent)); + public object GetComponent(Type componentType) { - if (this is TComponent component) return component; - if (_components.TryGetValue(typeof(TComponent), out var result)) + if (componentType.IsAssignableFrom(GetType())) return this; + if (_components.TryGetValue(componentType, out var result)) { - return (TComponent)result; + return result; } - else if (typeof(TComponent) == typeof(PlacementStrategy)) + else if (componentType == typeof(PlacementStrategy)) { - return (TComponent)(object)ClientObserversPlacement.Instance; + return ClientObserversPlacement.Instance; } lock (_lockObj) { - if (ActivationServices.GetService() is { } activatedComponent) + if (ActivationServices.GetService(componentType) is { } activatedComponent) { - return (TComponent)_components.GetOrAdd(typeof(TComponent), activatedComponent); + return _components.GetOrAdd(componentType, activatedComponent); } } return default; } - public TTarget GetTarget() where TTarget : class - { - if (this is TTarget target) return target; - return default; - } + public object GetTarget() => this; public void SetComponent(TComponent instance) where TComponent : class { diff --git a/src/Orleans.Core/Runtime/InvokableObjectManager.cs b/src/Orleans.Core/Runtime/InvokableObjectManager.cs index 48f752750f..99133fcf36 100644 --- a/src/Orleans.Core/Runtime/InvokableObjectManager.cs +++ b/src/Orleans.Core/Runtime/InvokableObjectManager.cs @@ -134,17 +134,18 @@ void IGrainContext.SetComponent(TComponent value) where TComponent : _manager.rootGrainContext.SetComponent(value); } - public TComponent GetComponent() where TComponent : class + public TComponent GetComponent() where TComponent : class => (TComponent)GetComponent(typeof(TComponent)); + public object GetComponent(Type componentType) { - if (this.LocalObject.Target is TComponent component) + if (componentType.IsAssignableFrom(this.LocalObject.Target?.GetType())) { - return component; + return LocalObject.Target; } - return _manager.rootGrainContext.GetComponent(); + return _manager.rootGrainContext.GetComponent(componentType); } - public TTarget GetTarget() where TTarget : class => (TTarget)this.LocalObject.Target; + public object GetTarget() => this.LocalObject.Target; bool IEquatable.Equals(IGrainContext other) => ReferenceEquals(this, other); diff --git a/src/Orleans.Runtime/Catalog/ActivationData.cs b/src/Orleans.Runtime/Catalog/ActivationData.cs index 5bb3b41027..7ce570d4fb 100644 --- a/src/Orleans.Runtime/Catalog/ActivationData.cs +++ b/src/Orleans.Runtime/Catalog/ActivationData.cs @@ -226,17 +226,17 @@ private DehydrationContextHolder? DehydrationContext public TimeSpan CollectionAgeLimit => _shared.CollectionAgeLimit; - public TTarget? GetTarget() where TTarget : class => (TTarget?)GrainInstance; + public object? GetTarget() => GrainInstance; - TComponent? ITargetHolder.GetComponent() where TComponent : class + object? ITargetHolder.GetComponent(Type componentType) { - var result = GetComponent(); - if (result is null && typeof(IGrainExtension).IsAssignableFrom(typeof(TComponent))) + var result = GetComponent(componentType); + if (result is null && typeof(IGrainExtension).IsAssignableFrom(componentType)) { - var implementation = ActivationServices.GetKeyedService(typeof(TComponent)); - if (implementation is not TComponent typedResult) + var implementation = ActivationServices.GetKeyedService(componentType); + if (implementation is not { } typedResult) { - throw new GrainExtensionNotInstalledException($"No extension of type {typeof(TComponent)} is installed on this instance and no implementations are registered for automated install"); + throw new GrainExtensionNotInstalledException($"No extension of type {componentType} is installed on this instance and no implementations are registered for automated install"); } SetComponent(typedResult); @@ -246,29 +246,30 @@ private DehydrationContextHolder? DehydrationContext return result; } - public TComponent? GetComponent() where TComponent : class + public TComponent? GetComponent() where TComponent : class => (TComponent?)GetComponent(typeof(TComponent)); + public object? GetComponent(Type componentType) { - TComponent? result; - if (GrainInstance is TComponent grainResult) + object? result; + if (componentType.IsAssignableFrom(GrainInstance?.GetType())) { - result = grainResult; + result = GrainInstance; } - else if (this is TComponent contextResult) + else if (componentType.IsAssignableFrom(GetType())) { - result = contextResult; + result = this; } - else if (_extras is { } components && components.TryGetValue(typeof(TComponent), out var resultObj)) + else if (_extras is { } components && components.TryGetValue(componentType, out var resultObj)) { - result = (TComponent)resultObj; + result = resultObj; } - else if (ActivationServices.GetService() is { } component) + else if (ActivationServices.GetService(componentType) is { } component) { SetComponent(component); result = component; } else { - result = _shared.GetComponent(); + result = _shared.GetComponent(componentType); } return result; @@ -814,7 +815,7 @@ public async ValueTask DisposeAsync() try { - var activator = _shared.GetComponent(); + var activator = _shared.GetComponent(typeof(IGrainActivator)) as IGrainActivator; if (activator != null && GrainInstance is { } instance) { await activator.DisposeInstance(this, instance); diff --git a/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs b/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs index 35c40dbeda..415e549a6c 100644 --- a/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs +++ b/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs @@ -22,7 +22,7 @@ namespace Orleans.Runtime; public class GrainTypeSharedContext { private readonly IServiceProvider _serviceProvider; - private readonly Dictionary _components = new(); + private readonly Dictionary _components = []; private InternalGrainRuntime? _internalGrainRuntime; public GrainTypeSharedContext( @@ -98,23 +98,27 @@ private static TimeSpan GetCollectionAgeLimit(GrainType grainType, Type grainCla /// /// Gets a component. /// - /// The type specified in the corresponding call. - public TComponent? GetComponent() + public TComponent? GetComponent() where TComponent : class => GetComponent(typeof(TComponent)) as TComponent; + + /// + /// Gets a component. + /// + public object? GetComponent(Type componentType) { - if (typeof(TComponent) == typeof(PlacementStrategy) && PlacementStrategy is TComponent component) + if (componentType == typeof(PlacementStrategy) && PlacementStrategy is { } component) { return component; } if (_components is null) return default; - _components.TryGetValue(typeof(TComponent), out var resultObj); - return (TComponent?)resultObj; + _components.TryGetValue(componentType, out var resultObj); + return resultObj; } /// /// Registers a component. /// - /// The type which can be used as a key to . + /// The type which can be used as a key to . public void SetComponent(TComponent? instance) { if (instance == null) diff --git a/src/Orleans.Runtime/Catalog/StatelessWorkerGrainContext.cs b/src/Orleans.Runtime/Catalog/StatelessWorkerGrainContext.cs index f956ea8090..5960f7bd53 100644 --- a/src/Orleans.Runtime/Catalog/StatelessWorkerGrainContext.cs +++ b/src/Orleans.Runtime/Catalog/StatelessWorkerGrainContext.cs @@ -100,11 +100,11 @@ private void EnqueueWorkItem(WorkItemType type, object state) public bool Equals([AllowNull] IGrainContext other) => other is not null && ActivationId.Equals(other.ActivationId); - public TComponent? GetComponent() where TComponent : class => this switch + public object? GetComponent(Type componentType) { - TComponent contextResult => contextResult, - _ => _shared.GetComponent() - }; + if (componentType.IsAssignableFrom(GetType())) return this; + return _shared.GetComponent(componentType); + } public void SetComponent(TComponent? instance) where TComponent : class { @@ -116,7 +116,7 @@ public void SetComponent(TComponent? instance) where TComponent : cl _shared.SetComponent(instance); } - public TTarget GetTarget() where TTarget : class => throw new NotImplementedException(); + public object GetTarget() => throw new NotImplementedException(); private async Task RunMessageLoop() { diff --git a/src/Orleans.Runtime/Core/HostedClient.cs b/src/Orleans.Runtime/Core/HostedClient.cs index 04faef384c..ebc84c1d63 100644 --- a/src/Orleans.Runtime/Core/HostedClient.cs +++ b/src/Orleans.Runtime/Core/HostedClient.cs @@ -138,23 +138,24 @@ public void DeleteObjectReference(IAddressable obj) } } - public TComponent? GetComponent() where TComponent : class + public TComponent? GetComponent() where TComponent : class => (TComponent?)GetComponent(typeof(TComponent)); + public object? GetComponent(Type componentType) { - if (this is TComponent component) return component; - if (_components.TryGetValue(typeof(TComponent), out var result)) + if (componentType.IsAssignableFrom(GetType())) return this; + if (_components.TryGetValue(componentType, out var result)) { - return (TComponent)result; + return result; } - else if (typeof(TComponent) == typeof(PlacementStrategy)) + else if (componentType == typeof(PlacementStrategy)) { - return (TComponent)(object)ClientObserversPlacement.Instance; + return ClientObserversPlacement.Instance; } lock (lockObj) { - if (ActivationServices.GetService() is { } activatedComponent) + if (ActivationServices.GetService(componentType) is { } activatedComponent) { - return (TComponent)_components.GetOrAdd(typeof(TComponent), activatedComponent); + return _components.GetOrAdd(componentType, activatedComponent); } } @@ -204,7 +205,7 @@ public void ReceiveMessage(object message) if (msg.Direction == Message.Directions.Response) { - // Requests are made through the runtime client, so deliver responses to the rutnime client so that the request callback can be executed. + // Requests are made through the runtime client, so deliver responses to the runtime client so that the request callback can be executed. this.runtimeClient.ReceiveResponse(msg); } else @@ -383,7 +384,7 @@ public TExtensionInterface GetExtension() } } - public TTarget GetTarget() where TTarget : class => throw new NotImplementedException(); + public object? GetTarget() => throw new NotImplementedException(); public void Activate(Dictionary? requestContext, CancellationToken cancellationToken) { } public void Deactivate(DeactivationReason deactivationReason, CancellationToken cancellationToken) { } public Task Deactivated => Task.CompletedTask; diff --git a/src/Orleans.Runtime/Core/SystemTarget.cs b/src/Orleans.Runtime/Core/SystemTarget.cs index c06a9b8db0..f4f72b36a6 100644 --- a/src/Orleans.Runtime/Core/SystemTarget.cs +++ b/src/Orleans.Runtime/Core/SystemTarget.cs @@ -101,22 +101,21 @@ internal SystemTarget(SystemTargetGrainId grainId, SiloAddress silo, ILoggerFact /// /// Gets the component with the specified type. /// - /// The component type. /// The component with the specified type. - public TComponent GetComponent() + public object GetComponent(Type componentType) { - TComponent result; - if (this is TComponent instanceResult) + object result; + if (componentType.IsAssignableFrom(GetType())) { - result = instanceResult; + result = this; } - else if (_components.TryGetValue(typeof(TComponent), out var resultObj)) + else if (_components.TryGetValue(componentType, out var resultObj)) { - result = (TComponent)resultObj; + result = resultObj; } - else if (typeof(TComponent) == typeof(PlacementStrategy)) + else if (componentType == typeof(PlacementStrategy)) { - result = (TComponent)(object)SystemTargetPlacementStrategy.Instance; + result = SystemTargetPlacementStrategy.Instance; } else { @@ -224,25 +223,6 @@ bool ISpanFormattable.TryFormat(Span destination, out int charsWritten, Re return (implementation, reference); } - /// - TComponent ITargetHolder.GetComponent() - { - var result = this.GetComponent(); - if (result is null && typeof(IGrainExtension).IsAssignableFrom(typeof(TComponent))) - { - var implementation = this.ActivationServices.GetKeyedService(typeof(TComponent)); - if (implementation is not TComponent typedResult) - { - throw new GrainExtensionNotInstalledException($"No extension of type {typeof(TComponent)} is installed on this instance and no implementations are registered for automated install"); - } - - this.SetComponent(typedResult); - result = typedResult; - } - - return result; - } - /// public TExtensionInterface GetExtension() where TExtensionInterface : class, IGrainExtension @@ -284,7 +264,7 @@ public void ReceiveMessage(object message) } /// - public TTarget GetTarget() where TTarget : class => (TTarget)(object)this; + public object GetTarget()=> this; /// public void Activate(Dictionary requestContext, CancellationToken cancellationToken) { } diff --git a/src/Orleans.Serialization/Invocation/ITargetHolder.cs b/src/Orleans.Serialization/Invocation/ITargetHolder.cs index 7184b27899..105fe40a09 100644 --- a/src/Orleans.Serialization/Invocation/ITargetHolder.cs +++ b/src/Orleans.Serialization/Invocation/ITargetHolder.cs @@ -1,4 +1,6 @@ #nullable enable +using System; + namespace Orleans.Serialization.Invocation; /// @@ -7,16 +9,15 @@ namespace Orleans.Serialization.Invocation; public interface ITargetHolder { /// - /// Gets the target. + /// Gets the target instance. /// - /// The target type. /// The target. - TTarget? GetTarget() where TTarget : class; + object? GetTarget(); /// /// Gets the component with the specified type. /// - /// The component type. + /// The component type. /// The component with the specified type. - TComponent? GetComponent() where TComponent : class; -} \ No newline at end of file + object? GetComponent(Type componentType); +} diff --git a/src/Orleans.Serialization/Invocation/TargetHolderExtensions.cs b/src/Orleans.Serialization/Invocation/TargetHolderExtensions.cs new file mode 100644 index 0000000000..428051ba98 --- /dev/null +++ b/src/Orleans.Serialization/Invocation/TargetHolderExtensions.cs @@ -0,0 +1,23 @@ +#nullable enable +using Orleans.Serialization.Invocation; +namespace Orleans.Runtime; + +/// +/// Extension methods for . +/// +public static class TargetHolderExtensions +{ + /// + /// Gets the target with the specified type. + /// + /// The target type. + /// The target. + public static TTarget? GetTarget(this ITargetHolder targetHolder) where TTarget : class => targetHolder.GetTarget() as TTarget; + + /// + /// Gets the component with the specified type. + /// + /// The component type. + /// The component with the specified type. + public static TComponent? GetComponent(this ITargetHolder targetHolder) where TComponent : class => targetHolder.GetComponent(typeof(TComponent)) as TComponent; +} \ No newline at end of file diff --git a/test/NonSilo.Tests/SchedulerTests/OrleansTaskSchedulerBasicTests.cs b/test/NonSilo.Tests/SchedulerTests/OrleansTaskSchedulerBasicTests.cs index d388aaa23e..3a6249d771 100644 --- a/test/NonSilo.Tests/SchedulerTests/OrleansTaskSchedulerBasicTests.cs +++ b/test/NonSilo.Tests/SchedulerTests/OrleansTaskSchedulerBasicTests.cs @@ -1,12 +1,9 @@ using Microsoft.Extensions.Logging; -using Orleans.Runtime; using Orleans.Runtime.Scheduler; using UnitTests.TesterInternal; using Xunit; using Xunit.Abstractions; using Orleans.TestingHost.Utils; -using Orleans.Internal; -using Orleans; // ReSharper disable ConvertToConstant.Local @@ -53,8 +50,8 @@ private UnitTestSchedulingContext() { } public void Deactivate(DeactivationReason deactivationReason, CancellationToken cancellationToken) { } public Task Deactivated => Task.CompletedTask; public void Dispose() => (Scheduler as IDisposable)?.Dispose(); - public TComponent GetComponent() where TComponent : class => throw new NotImplementedException(); - public TTarget GetTarget() where TTarget : class => throw new NotImplementedException(); + public object GetComponent(Type componentType) => throw new NotImplementedException(); + public object GetTarget() => throw new NotImplementedException(); public void ReceiveMessage(object message) => throw new NotImplementedException(); public void SetComponent(TComponent value) where TComponent : class => throw new NotImplementedException();