Skip to content

Commit 4a8ec1f

Browse files
NewellClarkgfoidl
andauthored
Implement struct enumerator for Tensor<T> (#46497)
* Added struct enumerator to Tensor<T> Fix issue #28391 * Added struct enumerator to Tensor<T> Added unit test for element enumeration. Fix issue #28391 * Added xml doc comments * Apply suggested edit to System.Numerics.Tensor.Enumerator Co-authored-by: Günther Foidl <[email protected]> * Fixed naming * Forgot to change back Current to auto-prop * Apply suggested changes -Make Reset and Dispose public -Fix null ref bug with Reset -Added tests for Reset and Dispose Co-authored-by: Günther Foidl <[email protected]>
1 parent 5ef6a06 commit 4a8ec1f

File tree

4 files changed

+152
-11
lines changed

4 files changed

+152
-11
lines changed

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs

+9
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ public virtual void Fill(T value) { }
119119
protected virtual int IndexOf(T item) { throw null; }
120120
public abstract System.Numerics.Tensors.Tensor<T> Reshape(System.ReadOnlySpan<int> dimensions);
121121
public abstract void SetValue(int index, T value);
122+
public struct Enumerator : System.Collections.Generic.IEnumerator<T>
123+
{
124+
public T Current { get; private set; }
125+
object? System.Collections.IEnumerator.Current => throw null;
126+
public bool MoveNext() => throw null;
127+
public void Reset() { }
128+
public void Dispose() { }
129+
}
130+
public Enumerator GetEnumerator() => throw null;
122131
void System.Collections.Generic.ICollection<T>.Add(T item) { }
123132
void System.Collections.Generic.ICollection<T>.Clear() { }
124133
bool System.Collections.Generic.ICollection<T>.Contains(T item) { throw null; }

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/Tensor.cs

+58-11
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,62 @@ public virtual T this[ReadOnlySpan<int> indices]
691691
/// <param name="value">The new value to set at the specified position in this Tensor.</param>
692692
public abstract void SetValue(int index, T value);
693693

694+
/// <summary>
695+
/// The type that implements enumerators for <see cref="Tensor{T}"/> instances.
696+
/// </summary>
697+
public struct Enumerator : IEnumerator<T>
698+
{
699+
private readonly Tensor<T> _tensor;
700+
private int _index;
701+
702+
internal Enumerator(Tensor<T> tensor)
703+
{
704+
Debug.Assert(tensor != null);
705+
706+
_tensor = tensor;
707+
_index = 0;
708+
Current = default;
709+
}
710+
711+
public T Current { get; private set; }
712+
713+
object? IEnumerator.Current => Current;
714+
715+
public bool MoveNext()
716+
{
717+
if (_index < _tensor.Length)
718+
{
719+
Current = _tensor.GetValue(_index);
720+
++_index;
721+
return true;
722+
}
723+
else
724+
{
725+
Current = default;
726+
return false;
727+
}
728+
}
729+
730+
/// <summary>
731+
/// Resets the enumerator to the beginning.
732+
/// </summary>
733+
public void Reset()
734+
{
735+
_index = 0;
736+
Current = default;
737+
}
738+
739+
/// <summary>
740+
/// Disposes the enumerator.
741+
/// </summary>
742+
public void Dispose() { }
743+
}
744+
745+
/// <summary>
746+
/// Gets an enumerator that enumerates the elements of the <see cref="Tensor{T}"/>.
747+
/// </summary>
748+
/// <returns>An enumerator for the current <see cref="Tensor{T}"/>.</returns>
749+
public Enumerator GetEnumerator() => new Enumerator(this);
694750

695751
#region statics
696752
/// <summary>
@@ -717,10 +773,7 @@ public static bool Equals(Tensor<T> left, Tensor<T> right)
717773
#endregion
718774

719775
#region IEnumerable members
720-
IEnumerator IEnumerable.GetEnumerator()
721-
{
722-
return ((IEnumerable<T>)this).GetEnumerator();
723-
}
776+
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
724777
#endregion
725778

726779
#region ICollection members
@@ -828,13 +881,7 @@ void IList.RemoveAt(int index)
828881
#endregion
829882

830883
#region IEnumerable<T> members
831-
IEnumerator<T> IEnumerable<T>.GetEnumerator()
832-
{
833-
for (int i = 0; i < Length; i++)
834-
{
835-
yield return GetValue(i);
836-
}
837-
}
884+
IEnumerator<T> IEnumerable<T>.GetEnumerator() => GetEnumerator();
838885
#endregion
839886

840887
#region ICollection<T> members

src/libraries/System.Numerics.Tensors/tests/TensorTests.cs

+63
Original file line numberDiff line numberDiff line change
@@ -2419,5 +2419,68 @@ public void TestIReadOnlyTMembers(TensorConstructor constructor)
24192419
int expectedIndexValue = constructor.IsReversedStride ? 4 : 2;
24202420
Assert.Equal(expectedIndexValue, tensorList[1]);
24212421
}
2422+
2423+
[Theory]
2424+
[MemberData(nameof(GetConstructedTensors))]
2425+
public void TestGetEnumerator(Tensor<int> tensor)
2426+
{
2427+
static IEnumerable<int> GetExpected(Tensor<int> tensor)
2428+
{
2429+
for (int index = 0; index < tensor.Length; ++index)
2430+
yield return tensor.GetValue(index);
2431+
}
2432+
2433+
Assert.Equal(GetExpected(tensor), tensor);
2434+
}
2435+
2436+
[Theory]
2437+
[MemberData(nameof(GetConstructedTensors))]
2438+
public void TestEnumeratorReset(Tensor<int> tensor)
2439+
{
2440+
static long AdvanceEnumerator(ref Tensor<int>.Enumerator enumerator, long maxCount)
2441+
{
2442+
long count = 0;
2443+
while (count < maxCount && enumerator.MoveNext())
2444+
count++;
2445+
2446+
return count;
2447+
}
2448+
2449+
static void TestStepCountIfInRange(Tensor<int> tensor, long stepCount)
2450+
{
2451+
if (stepCount < 0 || stepCount > tensor.Length)
2452+
return;
2453+
2454+
var enumerator = tensor.GetEnumerator();
2455+
long actualStepCount = AdvanceEnumerator(ref enumerator, stepCount);
2456+
2457+
Assert.Equal(stepCount, actualStepCount);
2458+
2459+
enumerator.Reset();
2460+
2461+
var itemsPostReset = new List<int>();
2462+
while (enumerator.MoveNext())
2463+
itemsPostReset.Add(enumerator.Current);
2464+
2465+
Assert.Equal(tensor, itemsPostReset);
2466+
}
2467+
2468+
TestStepCountIfInRange(tensor, 1);
2469+
TestStepCountIfInRange(tensor, tensor.Length - 1);
2470+
TestStepCountIfInRange(tensor, tensor.Length / 4);
2471+
TestStepCountIfInRange(tensor, tensor.Length - tensor.Length / 4);
2472+
TestStepCountIfInRange(tensor, tensor.Length / 2);
2473+
TestStepCountIfInRange(tensor, tensor.Length);
2474+
}
2475+
2476+
[Theory]
2477+
[MemberData(nameof(GetConstructedTensors))]
2478+
public void TestEnumeratorDispose_DoesNotThrow(Tensor<int> tensor)
2479+
{
2480+
var enumerator = tensor.GetEnumerator();
2481+
2482+
enumerator.Dispose();
2483+
enumerator.Dispose();
2484+
}
24222485
}
24232486
}

src/libraries/System.Numerics.Tensors/tests/TensorTestsBase.cs

+22
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Collections.Generic;
5+
using System.Linq;
56

67
namespace System.Numerics.Tensors.Tests
78
{
@@ -141,6 +142,27 @@ public static IEnumerable<object[]> GetTensorAndResultConstructor()
141142
}
142143
}
143144

145+
public static IEnumerable<object[]> GetConstructedTensors()
146+
{
147+
foreach (var ctor in GetSingleTensorConstructors().Select(x => (TensorConstructor)x[0]))
148+
{
149+
yield return new object[] { ctor.CreateFromArray<int>(Array.Empty<int>()) };
150+
yield return new object[] { ctor.CreateFromArray<int>(new[] { 7 }) };
151+
yield return new object[] { ctor.CreateFromArray<int>(new[] { 7, 14 }) };
152+
yield return new object[] { ctor.CreateFromArray<int>(new[] { 7, 14, 21 }) };
153+
yield return new object[]
154+
{
155+
ctor.CreateFromArray<int>(new[,]
156+
{
157+
{ 3, 6, 9 },
158+
{ 5, 10, 15 },
159+
{ 7, 14, 21 },
160+
{ 11, 22, 33 }
161+
})
162+
};
163+
}
164+
}
165+
144166
public static NativeMemory<T> NativeMemoryFromArray<T>(T[] array)
145167
{
146168
return NativeMemoryFromArray<T>((Array)array);

0 commit comments

Comments
 (0)