diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs index 9af0f37186cca6..b54d42ebf768a2 100644 --- a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs +++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs @@ -119,6 +119,15 @@ public virtual void Fill(T value) { } protected virtual int IndexOf(T item) { throw null; } public abstract System.Numerics.Tensors.Tensor Reshape(System.ReadOnlySpan dimensions); public abstract void SetValue(int index, T value); + public struct Enumerator : System.Collections.Generic.IEnumerator + { + public T Current { get; private set; } + object? System.Collections.IEnumerator.Current => throw null; + public bool MoveNext() => throw null; + public void Reset() { } + public void Dispose() { } + } + public Enumerator GetEnumerator() => throw null; void System.Collections.Generic.ICollection.Add(T item) { } void System.Collections.Generic.ICollection.Clear() { } bool System.Collections.Generic.ICollection.Contains(T item) { throw null; } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/Tensor.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/Tensor.cs index bcb3815320dbb6..270da201c53b48 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/Tensor.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/Tensor.cs @@ -691,6 +691,62 @@ public virtual T this[ReadOnlySpan indices] /// The new value to set at the specified position in this Tensor. public abstract void SetValue(int index, T value); + /// + /// The type that implements enumerators for instances. + /// + public struct Enumerator : IEnumerator + { + private readonly Tensor _tensor; + private int _index; + + internal Enumerator(Tensor tensor) + { + Debug.Assert(tensor != null); + + _tensor = tensor; + _index = 0; + Current = default; + } + + public T Current { get; private set; } + + object? IEnumerator.Current => Current; + + public bool MoveNext() + { + if (_index < _tensor.Length) + { + Current = _tensor.GetValue(_index); + ++_index; + return true; + } + else + { + Current = default; + return false; + } + } + + /// + /// Resets the enumerator to the beginning. + /// + public void Reset() + { + _index = 0; + Current = default; + } + + /// + /// Disposes the enumerator. + /// + public void Dispose() { } + } + + /// + /// Gets an enumerator that enumerates the elements of the . + /// + /// An enumerator for the current . + public Enumerator GetEnumerator() => new Enumerator(this); #region statics /// @@ -717,10 +773,7 @@ public static bool Equals(Tensor left, Tensor right) #endregion #region IEnumerable members - IEnumerator IEnumerable.GetEnumerator() - { - return ((IEnumerable)this).GetEnumerator(); - } + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); #endregion #region ICollection members @@ -828,13 +881,7 @@ void IList.RemoveAt(int index) #endregion #region IEnumerable members - IEnumerator IEnumerable.GetEnumerator() - { - for (int i = 0; i < Length; i++) - { - yield return GetValue(i); - } - } + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); #endregion #region ICollection members diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs index 2c9c7ba33bc040..0a5d8224c76635 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs @@ -2419,5 +2419,68 @@ public void TestIReadOnlyTMembers(TensorConstructor constructor) int expectedIndexValue = constructor.IsReversedStride ? 4 : 2; Assert.Equal(expectedIndexValue, tensorList[1]); } + + [Theory] + [MemberData(nameof(GetConstructedTensors))] + public void TestGetEnumerator(Tensor tensor) + { + static IEnumerable GetExpected(Tensor tensor) + { + for (int index = 0; index < tensor.Length; ++index) + yield return tensor.GetValue(index); + } + + Assert.Equal(GetExpected(tensor), tensor); + } + + [Theory] + [MemberData(nameof(GetConstructedTensors))] + public void TestEnumeratorReset(Tensor tensor) + { + static long AdvanceEnumerator(ref Tensor.Enumerator enumerator, long maxCount) + { + long count = 0; + while (count < maxCount && enumerator.MoveNext()) + count++; + + return count; + } + + static void TestStepCountIfInRange(Tensor tensor, long stepCount) + { + if (stepCount < 0 || stepCount > tensor.Length) + return; + + var enumerator = tensor.GetEnumerator(); + long actualStepCount = AdvanceEnumerator(ref enumerator, stepCount); + + Assert.Equal(stepCount, actualStepCount); + + enumerator.Reset(); + + var itemsPostReset = new List(); + while (enumerator.MoveNext()) + itemsPostReset.Add(enumerator.Current); + + Assert.Equal(tensor, itemsPostReset); + } + + TestStepCountIfInRange(tensor, 1); + TestStepCountIfInRange(tensor, tensor.Length - 1); + TestStepCountIfInRange(tensor, tensor.Length / 4); + TestStepCountIfInRange(tensor, tensor.Length - tensor.Length / 4); + TestStepCountIfInRange(tensor, tensor.Length / 2); + TestStepCountIfInRange(tensor, tensor.Length); + } + + [Theory] + [MemberData(nameof(GetConstructedTensors))] + public void TestEnumeratorDispose_DoesNotThrow(Tensor tensor) + { + var enumerator = tensor.GetEnumerator(); + + enumerator.Dispose(); + enumerator.Dispose(); + } } } diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorTestsBase.cs b/src/libraries/System.Numerics.Tensors/tests/TensorTestsBase.cs index 26ed47ec639ea4..9774dd22662e6a 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorTestsBase.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorTestsBase.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Linq; namespace System.Numerics.Tensors.Tests { @@ -141,6 +142,27 @@ public static IEnumerable GetTensorAndResultConstructor() } } + public static IEnumerable GetConstructedTensors() + { + foreach (var ctor in GetSingleTensorConstructors().Select(x => (TensorConstructor)x[0])) + { + yield return new object[] { ctor.CreateFromArray(Array.Empty()) }; + yield return new object[] { ctor.CreateFromArray(new[] { 7 }) }; + yield return new object[] { ctor.CreateFromArray(new[] { 7, 14 }) }; + yield return new object[] { ctor.CreateFromArray(new[] { 7, 14, 21 }) }; + yield return new object[] + { + ctor.CreateFromArray(new[,] + { + { 3, 6, 9 }, + { 5, 10, 15 }, + { 7, 14, 21 }, + { 11, 22, 33 } + }) + }; + } + } + public static NativeMemory NativeMemoryFromArray(T[] array) { return NativeMemoryFromArray((Array)array);