Skip to content

Commit

Permalink
Extract ICallInfo interfaces (#641)
Browse files Browse the repository at this point in the history
Apply review comments.
  • Loading branch information
dtchepak committed Mar 28, 2021
1 parent 035b939 commit ab6733c
Show file tree
Hide file tree
Showing 18 changed files with 152 additions and 116 deletions.
26 changes: 13 additions & 13 deletions src/NSubstitute/Callback.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class Callback
/// </summary>
/// <param name="doThis"></param>
/// <returns></returns>
public static ConfiguredCallback First(Action<CallInfo> doThis)
public static ConfiguredCallback First(Action<ICallInfo> doThis)
{
return new ConfiguredCallback().Then(doThis);
}
Expand All @@ -28,7 +28,7 @@ public static ConfiguredCallback First(Action<CallInfo> doThis)
/// </summary>
/// <param name="doThis"></param>
/// <returns></returns>
public static Callback Always(Action<CallInfo> doThis)
public static Callback Always(Action<ICallInfo> doThis)
{
return new ConfiguredCallback().AndAlways(doThis);
}
Expand All @@ -38,7 +38,7 @@ public static Callback Always(Action<CallInfo> doThis)
/// </summary>
/// <param name="throwThis"></param>
/// <returns></returns>
public static ConfiguredCallback FirstThrow<TException>(Func<CallInfo, TException> throwThis) where TException : Exception
public static ConfiguredCallback FirstThrow<TException>(Func<ICallInfo, TException> throwThis) where TException : Exception
{
return new ConfiguredCallback().ThenThrow(throwThis);
}
Expand All @@ -59,7 +59,7 @@ public static ConfiguredCallback FirstThrow<TException>(TException exception) wh
/// <typeparam name="TException">The type of the exception.</typeparam>
/// <param name="throwThis">The throw this.</param>
/// <returns></returns>
public static Callback AlwaysThrow<TException>(Func<CallInfo, TException> throwThis) where TException : Exception
public static Callback AlwaysThrow<TException>(Func<ICallInfo, TException> throwThis) where TException : Exception
{
return new ConfiguredCallback().AndAlways(ToCallback(throwThis));
}
Expand All @@ -75,33 +75,33 @@ public static Callback AlwaysThrow<TException>(TException exception) where TExce
return AlwaysThrow(_ => exception);
}

protected static Action<CallInfo> ToCallback<TException>(Func<CallInfo, TException> throwThis)
protected static Action<ICallInfo> ToCallback<TException>(Func<ICallInfo, TException> throwThis)
where TException : notnull, Exception
{
return ci => { if (throwThis != null) throw throwThis(ci); };
}

internal Callback() { }
private readonly ConcurrentQueue<Action<CallInfo>> callbackQueue = new ConcurrentQueue<Action<CallInfo>>();
private Action<CallInfo> alwaysDo = x => { };
private Action<CallInfo> keepDoing = x => { };
private readonly ConcurrentQueue<Action<ICallInfo>> callbackQueue = new ConcurrentQueue<Action<ICallInfo>>();
private Action<ICallInfo> alwaysDo = x => { };
private Action<ICallInfo> keepDoing = x => { };

protected void AddCallback(Action<CallInfo> doThis)
protected void AddCallback(Action<ICallInfo> doThis)
{
callbackQueue.Enqueue(doThis);
}

protected void SetAlwaysDo(Action<CallInfo> always)
protected void SetAlwaysDo(Action<ICallInfo> always)
{
alwaysDo = always ?? (_ => { });
}

protected void SetKeepDoing(Action<CallInfo> keep)
protected void SetKeepDoing(Action<ICallInfo> keep)
{
keepDoing = keep ?? (_ => { });
}

public void Call(CallInfo callInfo)
public void Call(ICallInfo callInfo)
{
try
{
Expand All @@ -113,7 +113,7 @@ public void Call(CallInfo callInfo)
}
}

private void CallFromStack(CallInfo callInfo)
private void CallFromStack(ICallInfo callInfo)
{
if (callbackQueue.TryDequeue(out var callback))
{
Expand Down
10 changes: 5 additions & 5 deletions src/NSubstitute/Callbacks/ConfiguredCallback.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ internal ConfiguredCallback() { }
/// <summary>
/// Perform this action once in chain of called callbacks.
/// </summary>
public ConfiguredCallback Then(Action<CallInfo> doThis)
public ConfiguredCallback Then(Action<ICallInfo> doThis)
{
AddCallback(doThis);
return this;
Expand All @@ -22,7 +22,7 @@ public ConfiguredCallback Then(Action<CallInfo> doThis)
/// <summary>
/// Keep doing this action after the other callbacks have run.
/// </summary>
public EndCallbackChain ThenKeepDoing(Action<CallInfo> doThis)
public EndCallbackChain ThenKeepDoing(Action<ICallInfo> doThis)
{
SetKeepDoing(doThis);
return this;
Expand All @@ -31,7 +31,7 @@ public EndCallbackChain ThenKeepDoing(Action<CallInfo> doThis)
/// <summary>
/// Keep throwing this exception after the other callbacks have run.
/// </summary>
public EndCallbackChain ThenKeepThrowing<TException>(Func<CallInfo, TException> throwThis) where TException : Exception =>
public EndCallbackChain ThenKeepThrowing<TException>(Func<ICallInfo, TException> throwThis) where TException : Exception =>
ThenKeepDoing(ToCallback(throwThis));

/// <summary>
Expand All @@ -45,7 +45,7 @@ public EndCallbackChain ThenKeepThrowing<TException>(TException throwThis) where
/// </summary>
/// <typeparam name="TException">The type of the exception</typeparam>
/// <param name="throwThis">Produce the exception to throw for a CallInfo</param>
public ConfiguredCallback ThenThrow<TException>(Func<CallInfo, TException> throwThis) where TException : Exception
public ConfiguredCallback ThenThrow<TException>(Func<ICallInfo, TException> throwThis) where TException : Exception
{
AddCallback(ToCallback(throwThis));
return this;
Expand All @@ -68,7 +68,7 @@ internal EndCallbackChain() { }
/// Perform the given action for every call.
/// </summary>
/// <param name="doThis">The action to perform for every call</param>
public Callback AndAlways(Action<CallInfo> doThis)
public Callback AndAlways(Action<ICallInfo> doThis)
{
SetAlwaysDo(doThis);
return this;
Expand Down
82 changes: 24 additions & 58 deletions src/NSubstitute/Core/CallInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@

namespace NSubstitute.Core
{
public class CallInfo
public class CallInfo : ICallInfo
{
private readonly Argument[] _callArguments;
private readonly Func<Maybe<object>> _baseResult;

public CallInfo(Argument[] callArguments, Func<Maybe<object>> baseResult)
{
public CallInfo(Argument[] callArguments, Func<Maybe<object>> baseResult) {
_callArguments = callArguments;
_baseResult = baseResult;
}
Expand All @@ -33,63 +32,41 @@ protected object GetBaseResult() {
return _baseResult().ValueOr(() => throw new NoBaseImplementationException());
}

/// <summary>
/// Gets the nth argument to this call.
/// </summary>
/// <param name="index">Index of argument</param>
/// <returns>The value of the argument at the given index</returns>
public object this[int index]
{
/// <inheritdoc/>
public object this[int index] {
get => _callArguments[index].Value;
set
{
set {
var argument = _callArguments[index];
EnsureArgIsSettable(argument, index, value);
argument.Value = value;
}
}

private void EnsureArgIsSettable(Argument argument, int index, object value)
{
if (!argument.IsByRef)
{
private void EnsureArgIsSettable(Argument argument, int index, object value) {
if (!argument.IsByRef) {
throw new ArgumentIsNotOutOrRefException(index, argument.DeclaredType);
}

if (value != null && !argument.CanSetValueWithInstanceOf(value.GetType()))
{
if (value != null && !argument.CanSetValueWithInstanceOf(value.GetType())) {
throw new ArgumentSetWithIncompatibleValueException(index, argument.DeclaredType, value.GetType());
}
}

/// <summary>
/// Get the arguments passed to this call.
/// </summary>
/// <returns>Array of all arguments passed to this call</returns>
/// <inheritdoc/>
public object[] Args() => _callArguments.Select(x => x.Value).ToArray();

/// <summary>
/// Gets the types of all the arguments passed to this call.
/// </summary>
/// <returns>Array of types of all arguments passed to this call</returns>
/// <inheritdoc/>
public Type[] ArgTypes() => _callArguments.Select(x => x.DeclaredType).ToArray();

/// <summary>
/// Gets the argument of type `T` passed to this call. This will throw if there are no arguments
/// of this type, or if there is more than one matching argument.
/// </summary>
/// <typeparam name="T">The type of the argument to retrieve</typeparam>
/// <returns>The argument passed to the call, or throws if there is not exactly one argument of this type</returns>
public T Arg<T>()
{
/// <inheritdoc/>
public T Arg<T>() {
T arg;
if (TryGetArg(x => x.IsDeclaredTypeEqualToOrByRefVersionOf(typeof(T)), out arg)) return arg;
if (TryGetArg(x => x.IsValueAssignableTo(typeof(T)), out arg)) return arg;
throw new ArgumentNotFoundException("Can not find an argument of type " + typeof(T).FullName + " to this call.");
}

private bool TryGetArg<T>(Func<Argument, bool> condition, [MaybeNullWhen(false)] out T value)
{
private bool TryGetArg<T>(Func<Argument, bool> condition, [MaybeNullWhen(false)] out T value) {
value = default;

var matchingArgs = _callArguments.Where(condition);
Expand All @@ -100,10 +77,8 @@ private bool TryGetArg<T>(Func<Argument, bool> condition, [MaybeNullWhen(false)]
return true;
}

private void ThrowIfMoreThanOne<T>(IEnumerable<Argument> arguments)
{
if (arguments.Skip(1).Any())
{
private void ThrowIfMoreThanOne<T>(IEnumerable<Argument> arguments) {
if (arguments.Skip(1).Any()) {
throw new AmbiguousArgumentsException(
"There is more than one argument of type " + typeof(T).FullName + " to this call.\n" +
"The call signature is (" + DisplayTypes(ArgTypes()) + ")\n" +
Expand All @@ -112,33 +87,24 @@ private void ThrowIfMoreThanOne<T>(IEnumerable<Argument> arguments)
}
}

/// <summary>
/// Gets the argument passed to this call at the specified zero-based position, converted to type `T`.
/// This will throw if there are no arguments, if the argument is out of range or if it
/// cannot be converted to the specified type.
/// </summary>
/// <typeparam name="T">The type of the argument to retrieve</typeparam>
/// <param name="position">The zero-based position of the argument to retrieve</param>
/// <returns>The argument passed to the call, or throws if there is not exactly one argument of this type</returns>
public T ArgAt<T>(int position)
{
if (position >= _callArguments.Length)
{
/// <inheritdoc/>
public T ArgAt<T>(int position) {
if (position >= _callArguments.Length) {
throw new ArgumentOutOfRangeException(nameof(position), $"There is no argument at position {position}");
}

try
{
return (T) _callArguments[position].Value!;
}
catch (InvalidCastException)
{
try {
return (T)_callArguments[position].Value!;
} catch (InvalidCastException) {
throw new InvalidCastException(
$"Couldn't convert parameter at position {position} to type {typeof(T).FullName}");
}
}

private static string DisplayTypes(IEnumerable<Type> types) =>
string.Join(", ", types.Select(x => x.Name).ToArray());

/// <inheritdoc/>
public ICallInfo<T> ForCallReturning<T>() => new CallInfo<T>(this);
}
}
2 changes: 1 addition & 1 deletion src/NSubstitute/Core/CallInfoWithReturns.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
/// Information for a call that returns a value of type <c>T</c>.
/// </summary>
/// <typeparam name="T"></typeparam>
public class CallInfo<T> : CallInfo
public class CallInfo<T> : CallInfo, ICallInfo<T>
{
internal CallInfo(CallInfo info) : base(info) {
}
Expand Down
32 changes: 17 additions & 15 deletions src/NSubstitute/Core/IReturn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ namespace NSubstitute.Core
{
public interface IReturn
{
object? ReturnFor(CallInfo info);
object? ReturnFor(ICallInfo info);
Type? TypeOrNull();
bool CanBeAssignedTo(Type t);
}

/// <summary>
/// Performance optimization. Allows to not construct <see cref="CallInfo"/> if configured result doesn't depend on it.
/// Performance optimization. Allows to not construct <see cref="ICallInfo"/> if configured result doesn't depend on it.
/// </summary>
internal interface ICallIndependentReturn
{
Expand All @@ -32,25 +32,25 @@ public ReturnValue(object? value)
}

public object? GetReturnValue() => _value;
public object? ReturnFor(CallInfo info) => GetReturnValue();
public object? ReturnFor(ICallInfo info) => GetReturnValue();
public Type? TypeOrNull() => _value?.GetType();
public bool CanBeAssignedTo(Type t) => _value.IsCompatibleWith(t);
}

public class ReturnValueFromFunc<T> : IReturn
{
private readonly Func<CallInfo<T>, T?> _funcToReturnValue;
private readonly Func<ICallInfo<T>, T?> _funcToReturnValue;

public ReturnValueFromFunc(Func<CallInfo<T>, T?>? funcToReturnValue)
public ReturnValueFromFunc(Func<ICallInfo<T>, T?>? funcToReturnValue)
{
_funcToReturnValue = funcToReturnValue ?? ReturnNull();
}

public object? ReturnFor(CallInfo info) => _funcToReturnValue(new CallInfo<T>(info));
public object? ReturnFor(ICallInfo info) => _funcToReturnValue(info.ForCallReturning<T>());
public Type TypeOrNull() => typeof(T);
public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t);

private static Func<CallInfo, T?> ReturnNull()
private static Func<ICallInfo, T?> ReturnNull()
{
if (typeof(T).GetTypeInfo().IsValueType) throw new CannotReturnNullForValueType(typeof(T));
return x => default;
Expand All @@ -69,7 +69,7 @@ public ReturnMultipleValues(T?[] values)
}

public object? GetReturnValue() => GetNext();
public object? ReturnFor(CallInfo info) => GetReturnValue();
public object? ReturnFor(ICallInfo info) => GetReturnValue();
public Type TypeOrNull() => typeof(T);
public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t);

Expand All @@ -78,20 +78,22 @@ public ReturnMultipleValues(T?[] values)

public class ReturnMultipleFuncsValues<T> : IReturn
{
private readonly ConcurrentQueue<Func<CallInfo<T>, T?>> _funcsToReturn;
private readonly Func<CallInfo<T>, T?> _lastFunc;
private readonly ConcurrentQueue<Func<ICallInfo<T>, T?>> _funcsToReturn;
private readonly Func<ICallInfo<T>, T?> _lastFunc;

public ReturnMultipleFuncsValues(Func<CallInfo<T>, T?>[] funcs)
public ReturnMultipleFuncsValues(Func<ICallInfo<T>, T?>[] funcs)
{
_funcsToReturn = new ConcurrentQueue<Func<CallInfo<T>, T?>>(funcs);
_funcsToReturn = new ConcurrentQueue<Func<ICallInfo<T>, T?>>(funcs);
_lastFunc = funcs.Last();
}

public object? ReturnFor(CallInfo info) => GetNext(info);
public object? ReturnFor(ICallInfo info) => GetNext(info);
public Type TypeOrNull() => typeof(T);
public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t);

private T? GetNext(CallInfo info) =>
_funcsToReturn.TryDequeue(out var nextFunc) ? nextFunc(new CallInfo<T>(info)) : _lastFunc(new CallInfo<T>(info));
private T? GetNext(ICallInfo info) =>
_funcsToReturn.TryDequeue(out var nextFunc)
? nextFunc(info.ForCallReturning<T>())
: _lastFunc(info.ForCallReturning<T>());
}
}
4 changes: 2 additions & 2 deletions src/NSubstitute/Core/WhenCalled.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public WhenCalled(ISubstitutionContext context, T substitute, Action<T> call, Ma
/// Perform this action when called.
/// </summary>
/// <param name="callbackWithArguments"></param>
public void Do(Action<CallInfo> callbackWithArguments)
public void Do(Action<ICallInfo> callbackWithArguments)
{
_threadContext.SetNextRoute(_callRouter, x => _routeFactory.DoWhenCalled(x, callbackWithArguments, _matchArgs));
_call(_substitute);
Expand Down Expand Up @@ -82,7 +82,7 @@ public void Throw(Exception exception) =>
/// <summary>
/// Throw an exception generated by the specified function when called.
/// </summary>
public void Throw(Func<CallInfo, Exception> createException) =>
public void Throw(Func<ICallInfo, Exception> createException) =>
Do(ci => throw createException(ci));
}
}
Loading

0 comments on commit ab6733c

Please sign in to comment.