Skip to content

Commit 7306ee7

Browse files
authored
Fix consume ValueTask backed by IValueTaskSource (#2108)
* Added AwaitHelper to properly wait for ValueTasks. * Adjust `AwaitHelper` to allow multiple threads to use it concurrently. * Changed AwaitHelper to static. * Add test case to make sure ValueTasks work properly with a race condition between `IsCompleted` and `OnCompleted`. Changed AwaitHelper to use `ManualResetEventSlim` instead of `Monitor.Wait`. * Make `ValueTaskWaiter.Wait` generic. * Compare types directly.
1 parent 0d30991 commit 7306ee7

File tree

10 files changed

+434
-188
lines changed

10 files changed

+434
-188
lines changed

src/BenchmarkDotNet/Code/DeclarationsProvider.cs

+5-9
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ private string GetMethodName(MethodInfo method)
6363
(method.ReturnType.GetGenericTypeDefinition() == typeof(Task<>) ||
6464
method.ReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))))
6565
{
66-
return $"() => {method.Name}().GetAwaiter().GetResult()";
66+
return $"() => BenchmarkDotNet.Helpers.AwaitHelper.GetResult({method.Name}())";
6767
}
6868

6969
return method.Name;
@@ -149,12 +149,10 @@ internal class TaskDeclarationsProvider : VoidDeclarationsProvider
149149
{
150150
public TaskDeclarationsProvider(Descriptor descriptor) : base(descriptor) { }
151151

152-
// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
153-
// and will eventually throw actual exception, not aggregated one
154152
public override string WorkloadMethodDelegate(string passArguments)
155-
=> $"({passArguments}) => {{ {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
153+
=> $"({passArguments}) => {{ BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";
156154

157-
public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
155+
public override string GetWorkloadMethodCall(string passArguments) => $"BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";
158156

159157
protected override Type WorkloadMethodReturnType => typeof(void);
160158
}
@@ -168,11 +166,9 @@ public GenericTaskDeclarationsProvider(Descriptor descriptor) : base(descriptor)
168166

169167
protected override Type WorkloadMethodReturnType => Descriptor.WorkloadMethod.ReturnType.GetTypeInfo().GetGenericArguments().Single();
170168

171-
// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
172-
// and will eventually throw actual exception, not aggregated one
173169
public override string WorkloadMethodDelegate(string passArguments)
174-
=> $"({passArguments}) => {{ return {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
170+
=> $"({passArguments}) => {{ return BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";
175171

176-
public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
172+
public override string GetWorkloadMethodCall(string passArguments) => $"BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";
177173
}
178174
}
+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
using System;
2+
using System.Linq;
3+
using System.Reflection;
4+
using System.Runtime.CompilerServices;
5+
using System.Threading;
6+
using System.Threading.Tasks;
7+
8+
namespace BenchmarkDotNet.Helpers
9+
{
10+
public static class AwaitHelper
11+
{
12+
private class ValueTaskWaiter
13+
{
14+
// We use thread static field so that each thread uses its own individual callback and reset event.
15+
[ThreadStatic]
16+
private static ValueTaskWaiter ts_current;
17+
internal static ValueTaskWaiter Current => ts_current ??= new ValueTaskWaiter();
18+
19+
// We cache the callback to prevent allocations for memory diagnoser.
20+
private readonly Action awaiterCallback;
21+
private readonly ManualResetEventSlim resetEvent;
22+
23+
private ValueTaskWaiter()
24+
{
25+
resetEvent = new ();
26+
awaiterCallback = resetEvent.Set;
27+
}
28+
29+
internal void Wait<TAwaiter>(TAwaiter awaiter) where TAwaiter : ICriticalNotifyCompletion
30+
{
31+
resetEvent.Reset();
32+
awaiter.UnsafeOnCompleted(awaiterCallback);
33+
34+
// The fastest way to wait for completion is to spin a bit before waiting on the event. This is the same logic that Task.GetAwaiter().GetResult() uses.
35+
var spinner = new SpinWait();
36+
while (!resetEvent.IsSet)
37+
{
38+
if (spinner.NextSpinWillYield)
39+
{
40+
resetEvent.Wait();
41+
return;
42+
}
43+
spinner.SpinOnce();
44+
}
45+
}
46+
}
47+
48+
// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
49+
// and will eventually throw actual exception, not aggregated one
50+
public static void GetResult(Task task) => task.GetAwaiter().GetResult();
51+
52+
public static T GetResult<T>(Task<T> task) => task.GetAwaiter().GetResult();
53+
54+
// ValueTask can be backed by an IValueTaskSource that only supports asynchronous awaits,
55+
// so we have to hook up a callback instead of calling .GetAwaiter().GetResult() like we do for Task.
56+
// The alternative is to convert it to Task using .AsTask(), but that causes allocations which we must avoid for memory diagnoser.
57+
public static void GetResult(ValueTask task)
58+
{
59+
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
60+
var awaiter = task.ConfigureAwait(false).GetAwaiter();
61+
if (!awaiter.IsCompleted)
62+
{
63+
ValueTaskWaiter.Current.Wait(awaiter);
64+
}
65+
awaiter.GetResult();
66+
}
67+
68+
public static T GetResult<T>(ValueTask<T> task)
69+
{
70+
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
71+
var awaiter = task.ConfigureAwait(false).GetAwaiter();
72+
if (!awaiter.IsCompleted)
73+
{
74+
ValueTaskWaiter.Current.Wait(awaiter);
75+
}
76+
return awaiter.GetResult();
77+
}
78+
79+
internal static MethodInfo GetGetResultMethod(Type taskType)
80+
{
81+
if (!taskType.IsGenericType)
82+
{
83+
return typeof(AwaitHelper).GetMethod(nameof(AwaitHelper.GetResult), BindingFlags.Public | BindingFlags.Static, null, new Type[1] { taskType }, null);
84+
}
85+
86+
Type compareType = taskType.GetGenericTypeDefinition() == typeof(ValueTask<>) ? typeof(ValueTask<>)
87+
: typeof(Task).IsAssignableFrom(taskType.GetGenericTypeDefinition()) ? typeof(Task<>)
88+
: null;
89+
if (compareType == null)
90+
{
91+
return null;
92+
}
93+
var resultType = taskType
94+
.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)
95+
.ReturnType
96+
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)
97+
.ReturnType;
98+
return typeof(AwaitHelper).GetMethods(BindingFlags.Public | BindingFlags.Static)
99+
.First(m =>
100+
{
101+
if (m.Name != nameof(AwaitHelper.GetResult)) return false;
102+
Type paramType = m.GetParameters().First().ParameterType;
103+
return paramType.IsGenericType && paramType.GetGenericTypeDefinition() == compareType;
104+
})
105+
.MakeGenericMethod(new[] { resultType });
106+
}
107+
}
108+
}

src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/ConsumableTypeInfo.cs

+15-18
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using BenchmarkDotNet.Engines;
22
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
35
using System.Reflection;
46
using System.Runtime.CompilerServices;
57
using System.Threading.Tasks;
@@ -16,28 +18,24 @@ public ConsumableTypeInfo(Type methodReturnType)
1618

1719
OriginMethodReturnType = methodReturnType;
1820

19-
// Please note this code does not support await over extension methods.
20-
var getAwaiterMethod = methodReturnType.GetMethod(nameof(Task<int>.GetAwaiter), BindingFlagsPublicInstance);
21-
if (getAwaiterMethod == null)
21+
// Only support (Value)Task for parity with other toolchains (and so we can use AwaitHelper).
22+
IsAwaitable = methodReturnType == typeof(Task) || methodReturnType == typeof(ValueTask)
23+
|| (methodReturnType.GetTypeInfo().IsGenericType
24+
&& (methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(Task<>)
25+
|| methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(ValueTask<>)));
26+
27+
if (!IsAwaitable)
2228
{
2329
WorkloadMethodReturnType = methodReturnType;
2430
}
2531
else
2632
{
27-
var getResultMethod = getAwaiterMethod
33+
WorkloadMethodReturnType = methodReturnType
34+
.GetMethod(nameof(Task.GetAwaiter), BindingFlagsPublicInstance)
2835
.ReturnType
29-
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance);
30-
31-
if (getResultMethod == null)
32-
{
33-
WorkloadMethodReturnType = methodReturnType;
34-
}
35-
else
36-
{
37-
WorkloadMethodReturnType = getResultMethod.ReturnType;
38-
GetAwaiterMethod = getAwaiterMethod;
39-
GetResultMethod = getResultMethod;
40-
}
36+
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance)
37+
.ReturnType;
38+
GetResultMethod = Helpers.AwaitHelper.GetGetResultMethod(methodReturnType);
4139
}
4240

4341
if (WorkloadMethodReturnType == null)
@@ -74,14 +72,13 @@ public ConsumableTypeInfo(Type methodReturnType)
7472
public Type WorkloadMethodReturnType { get; }
7573
public Type OverheadMethodReturnType { get; }
7674

77-
public MethodInfo? GetAwaiterMethod { get; }
7875
public MethodInfo? GetResultMethod { get; }
7976

8077
public bool IsVoid { get; }
8178
public bool IsByRef { get; }
8279
public bool IsConsumable { get; }
8380
public FieldInfo? WorkloadConsumableField { get; }
8481

85-
public bool IsAwaitable => GetAwaiterMethod != null && GetResultMethod != null;
82+
public bool IsAwaitable { get; }
8683
}
8784
}

src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/RunnableEmitter.cs

+25-54
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ private void DefineFields()
434434

435435
Type argLocalsType;
436436
Type argFieldType;
437-
MethodInfo? opConversion = null;
437+
MethodInfo opConversion = null;
438438
if (parameterType.IsByRef)
439439
{
440440
argLocalsType = parameterType;
@@ -582,42 +582,28 @@ private MethodInfo EmitWorkloadImplementation(string methodName)
582582
workloadInvokeMethod.ReturnParameter,
583583
args);
584584
args = methodBuilder.GetEmitParameters(args);
585-
var callResultType = consumableInfo.OriginMethodReturnType;
586-
var awaiterType = consumableInfo.GetAwaiterMethod?.ReturnType
587-
?? throw new InvalidOperationException($"Bug: {nameof(consumableInfo.GetAwaiterMethod)} is null");
588585

589586
var ilBuilder = methodBuilder.GetILGenerator();
590587

591588
/*
592-
.locals init (
593-
[0] valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<int32>
594-
)
595-
*/
596-
var callResultLocal =
597-
ilBuilder.DeclareOptionalLocalForInstanceCall(callResultType, consumableInfo.GetAwaiterMethod);
598-
var awaiterLocal =
599-
ilBuilder.DeclareOptionalLocalForInstanceCall(awaiterType, consumableInfo.GetResultMethod);
600-
601-
/*
602-
// return TaskSample(arg0). ... ;
603-
IL_0000: ldarg.0
604-
IL_0001: ldarg.1
605-
IL_0002: call instance class [mscorlib]System.Threading.Tasks.Task`1<int32> [BenchmarkDotNet]BenchmarkDotNet.Samples.SampleBenchmark::TaskSample(int64)
606-
*/
589+
IL_0026: ldarg.0
590+
IL_0027: ldloc.0
591+
IL_0028: ldloc.1
592+
IL_0029: ldloc.2
593+
IL_002a: ldloc.3
594+
IL_002b: call instance class [System.Private.CoreLib]System.Threading.Tasks.Task`1<object> BenchmarkDotNet.Helpers.Runnable_0::WorkloadMethod(string, string, string, string)
595+
*/
607596
if (!Descriptor.WorkloadMethod.IsStatic)
608597
ilBuilder.Emit(OpCodes.Ldarg_0);
609598
ilBuilder.EmitLdargs(args);
610599
ilBuilder.Emit(OpCodes.Call, Descriptor.WorkloadMethod);
611600

612601
/*
613-
// ... .GetAwaiter().GetResult();
614-
IL_0007: callvirt instance valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<!0> class [mscorlib]System.Threading.Tasks.Task`1<int32>::GetAwaiter()
615-
IL_000c: stloc.0
616-
IL_000d: ldloca.s 0
617-
IL_000f: call instance !0 valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<int32>::GetResult()
618-
*/
619-
ilBuilder.EmitInstanceCallThisValueOnStack(callResultLocal, consumableInfo.GetAwaiterMethod);
620-
ilBuilder.EmitInstanceCallThisValueOnStack(awaiterLocal, consumableInfo.GetResultMethod);
602+
// BenchmarkDotNet.Helpers.AwaitHelper.GetResult(...);
603+
IL_000e: call !!0 BenchmarkDotNet.Helpers.AwaitHelper::GetResult<int32>(valuetype [System.Runtime]System.Threading.Tasks.ValueTask`1<!!0>)
604+
*/
605+
606+
ilBuilder.Emit(OpCodes.Call, consumableInfo.GetResultMethod);
621607

622608
/*
623609
IL_0014: ret
@@ -833,19 +819,6 @@ .locals init (
833819
var skipFirstArg = workloadMethod.IsStatic;
834820
var argLocals = EmitDeclareArgLocals(ilBuilder, skipFirstArg);
835821

836-
LocalBuilder? callResultLocal = null;
837-
LocalBuilder? awaiterLocal = null;
838-
if (consumableInfo.IsAwaitable)
839-
{
840-
var callResultType = consumableInfo.OriginMethodReturnType;
841-
var awaiterType = consumableInfo.GetAwaiterMethod?.ReturnType
842-
?? throw new InvalidOperationException($"Bug: {nameof(consumableInfo.GetAwaiterMethod)} is null");
843-
callResultLocal =
844-
ilBuilder.DeclareOptionalLocalForInstanceCall(callResultType, consumableInfo.GetAwaiterMethod);
845-
awaiterLocal =
846-
ilBuilder.DeclareOptionalLocalForInstanceCall(awaiterType, consumableInfo.GetResultMethod);
847-
}
848-
849822
consumeEmitter.DeclareDisassemblyDiagnoserLocals(ilBuilder);
850823

851824
var notElevenLabel = ilBuilder.DefineLabel();
@@ -870,29 +843,27 @@ .locals init (
870843
EmitLoadArgFieldsToLocals(ilBuilder, argLocals, skipFirstArg);
871844

872845
/*
873-
// return TaskSample(_argField) ... ;
874-
IL_0011: ldarg.0
875-
IL_0012: ldloc.0
876-
IL_0013: call instance class [mscorlib]System.Threading.Tasks.Task`1<int32> [BenchmarkDotNet]BenchmarkDotNet.Samples.SampleBenchmark::TaskSample(int64)
877-
IL_0018: ret
846+
IL_0026: ldarg.0
847+
IL_0027: ldloc.0
848+
IL_0028: ldloc.1
849+
IL_0029: ldloc.2
850+
IL_002a: ldloc.3
851+
IL_002b: call instance class [System.Private.CoreLib]System.Threading.Tasks.Task`1<object> BenchmarkDotNet.Helpers.Runnable_0::WorkloadMethod(string, string, string, string)
878852
*/
879-
880853
if (!workloadMethod.IsStatic)
854+
{
881855
ilBuilder.Emit(OpCodes.Ldarg_0);
856+
}
882857
ilBuilder.EmitLdLocals(argLocals);
883858
ilBuilder.Emit(OpCodes.Call, workloadMethod);
884859

885860
if (consumableInfo.IsAwaitable)
886861
{
887862
/*
888-
// ... .GetAwaiter().GetResult();
889-
IL_0007: callvirt instance valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<!0> class [mscorlib]System.Threading.Tasks.Task`1<int32>::GetAwaiter()
890-
IL_000c: stloc.0
891-
IL_000d: ldloca.s 0
892-
IL_000f: call instance !0 valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<int32>::GetResult()
893-
*/
894-
ilBuilder.EmitInstanceCallThisValueOnStack(callResultLocal, consumableInfo.GetAwaiterMethod);
895-
ilBuilder.EmitInstanceCallThisValueOnStack(awaiterLocal, consumableInfo.GetResultMethod);
863+
// BenchmarkDotNet.Helpers.AwaitHelper.GetResult(...);
864+
IL_000e: call !!0 BenchmarkDotNet.Helpers.AwaitHelper::GetResult<int32>(valuetype [System.Runtime]System.Threading.Tasks.ValueTask`1<!!0>)
865+
*/
866+
ilBuilder.Emit(OpCodes.Call, consumableInfo.GetResultMethod);
896867
}
897868

898869
/*

src/BenchmarkDotNet/Toolchains/InProcess/NoEmit/BenchmarkActionFactory_Implementations.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ public BenchmarkActionTask(object instance, MethodInfo method, int unrollFactor)
118118
private void Overhead() { }
119119

120120
// must be kept in sync with TaskDeclarationsProvider.TargetMethodDelegate
121-
private void ExecuteBlocking() => startTaskCallback.Invoke().GetAwaiter().GetResult();
121+
private void ExecuteBlocking() => Helpers.AwaitHelper.GetResult(startTaskCallback.Invoke());
122122

123123
[MethodImpl(CodeGenHelper.AggressiveOptimizationOption)]
124124
private void WorkloadActionUnroll(long repeatCount)
@@ -165,7 +165,7 @@ public BenchmarkActionTask(object instance, MethodInfo method, int unrollFactor)
165165
private T Overhead() => default;
166166

167167
// must be kept in sync with GenericTaskDeclarationsProvider.TargetMethodDelegate
168-
private T ExecuteBlocking() => startTaskCallback().GetAwaiter().GetResult();
168+
private T ExecuteBlocking() => Helpers.AwaitHelper.GetResult(startTaskCallback.Invoke());
169169

170170
private void InvokeSingleHardcoded() => result = callback();
171171

@@ -217,7 +217,7 @@ public BenchmarkActionValueTask(object instance, MethodInfo method, int unrollFa
217217
private T Overhead() => default;
218218

219219
// must be kept in sync with GenericTaskDeclarationsProvider.TargetMethodDelegate
220-
private T ExecuteBlocking() => startTaskCallback().GetAwaiter().GetResult();
220+
private T ExecuteBlocking() => Helpers.AwaitHelper.GetResult(startTaskCallback.Invoke());
221221

222222
private void InvokeSingleHardcoded() => result = callback();
223223

src/BenchmarkDotNet/Validators/ExecutionValidatorBase.cs

+3-15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Threading.Tasks;
66
using BenchmarkDotNet.Attributes;
77
using BenchmarkDotNet.Extensions;
8+
using BenchmarkDotNet.Helpers;
89
using BenchmarkDotNet.Running;
910

1011
namespace BenchmarkDotNet.Validators
@@ -130,21 +131,8 @@ private void TryToGetTaskResult(object result)
130131
return;
131132
}
132133

133-
var returnType = result.GetType();
134-
if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
135-
{
136-
var asTaskMethod = result.GetType().GetMethod("AsTask");
137-
result = asTaskMethod.Invoke(result, null);
138-
}
139-
140-
if (result is Task task)
141-
{
142-
task.GetAwaiter().GetResult();
143-
}
144-
else if (result is ValueTask valueTask)
145-
{
146-
valueTask.GetAwaiter().GetResult();
147-
}
134+
AwaitHelper.GetGetResultMethod(result.GetType())
135+
?.Invoke(null, new[] { result });
148136
}
149137

150138
private bool TryToSetParamsFields(object benchmarkTypeInstance, List<ValidationError> errors)

0 commit comments

Comments
 (0)