Skip to content

Implement struct enumerator for Tensor<T> #46497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 15, 2021
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ public virtual void Fill(T value) { }
protected virtual int IndexOf(T item) { throw null; }
public abstract System.Numerics.Tensors.Tensor<T> Reshape(System.ReadOnlySpan<int> dimensions);
public abstract void SetValue(int index, T value);
public struct Enumerator : System.Collections.Generic.IEnumerator<T>
{
public T Current { get; private set; }
object? System.Collections.IEnumerator.Current => throw null;
public bool MoveNext() => throw null;
public void Reset() { }
public void Dispose() { }
}
public Enumerator GetEnumerator() => throw null;
void System.Collections.Generic.ICollection<T>.Add(T item) { }
void System.Collections.Generic.ICollection<T>.Clear() { }
bool System.Collections.Generic.ICollection<T>.Contains(T item) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,62 @@ public virtual T this[ReadOnlySpan<int> indices]
/// <param name="value">The new value to set at the specified position in this Tensor.</param>
public abstract void SetValue(int index, T value);

/// <summary>
/// The type that implements enumerators for <see cref="Tensor{T}"/> instances.
/// </summary>
public struct Enumerator : IEnumerator<T>
{
private readonly Tensor<T> _tensor;
private int _index;

internal Enumerator(Tensor<T> tensor)
{
Debug.Assert(tensor != null);

_tensor = tensor;
_index = 0;
Current = default;
}

public T Current { get; private set; }

object? IEnumerator.Current => Current;

public bool MoveNext()
{
if (_index < _tensor.Length)
{
Current = _tensor.GetValue(_index);
++_index;
return true;
}
else
{
Current = default;
return false;
}
}

/// <summary>
/// Resets the enumerator to the beginning.
/// </summary>
public void Reset()
{
_index = 0;
Current = default;
}

/// <summary>
/// Disposes the enumerator.
/// </summary>
public void Dispose() { }
}

/// <summary>
/// Gets an enumerator that enumerates the elements of the <see cref="Tensor{T}"/>.
/// </summary>
/// <returns>An enumerator for the current <see cref="Tensor{T}"/>.</returns>
public Enumerator GetEnumerator() => new Enumerator(this);

#region statics
/// <summary>
Expand All @@ -717,10 +773,7 @@ public static bool Equals(Tensor<T> left, Tensor<T> right)
#endregion

#region IEnumerable members
IEnumerator IEnumerable.GetEnumerator()
{
return ((IEnumerable<T>)this).GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
#endregion

#region ICollection members
Expand Down Expand Up @@ -828,13 +881,7 @@ void IList.RemoveAt(int index)
#endregion

#region IEnumerable<T> members
IEnumerator<T> IEnumerable<T>.GetEnumerator()
{
for (int i = 0; i < Length; i++)
{
yield return GetValue(i);
}
}
IEnumerator<T> IEnumerable<T>.GetEnumerator() => GetEnumerator();
#endregion

#region ICollection<T> members
Expand Down
63 changes: 63 additions & 0 deletions src/libraries/System.Numerics.Tensors/tests/TensorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2419,5 +2419,68 @@ public void TestIReadOnlyTMembers(TensorConstructor constructor)
int expectedIndexValue = constructor.IsReversedStride ? 4 : 2;
Assert.Equal(expectedIndexValue, tensorList[1]);
}

[Theory]
[MemberData(nameof(GetConstructedTensors))]
public void TestGetEnumerator(Tensor<int> tensor)
{
static IEnumerable<int> GetExpected(Tensor<int> tensor)
{
for (int index = 0; index < tensor.Length; ++index)
yield return tensor.GetValue(index);
}

Assert.Equal(GetExpected(tensor), tensor);
}

[Theory]
[MemberData(nameof(GetConstructedTensors))]
public void TestEnumeratorReset(Tensor<int> tensor)
{
static long AdvanceEnumerator(ref Tensor<int>.Enumerator enumerator, long maxCount)
{
long count = 0;
while (count < maxCount && enumerator.MoveNext())
count++;

return count;
}

static void TestStepCountIfInRange(Tensor<int> tensor, long stepCount)
{
if (stepCount < 0 || stepCount > tensor.Length)
return;

var enumerator = tensor.GetEnumerator();
long actualStepCount = AdvanceEnumerator(ref enumerator, stepCount);

Assert.Equal(stepCount, actualStepCount);

enumerator.Reset();

var itemsPostReset = new List<int>();
while (enumerator.MoveNext())
itemsPostReset.Add(enumerator.Current);

Assert.Equal(tensor, itemsPostReset);
}

TestStepCountIfInRange(tensor, 1);
TestStepCountIfInRange(tensor, tensor.Length - 1);
TestStepCountIfInRange(tensor, tensor.Length / 4);
TestStepCountIfInRange(tensor, tensor.Length - tensor.Length / 4);
TestStepCountIfInRange(tensor, tensor.Length / 2);
TestStepCountIfInRange(tensor, tensor.Length);
}

[Theory]
[MemberData(nameof(GetConstructedTensors))]
public void TestEnumeratorDispose_DoesNotThrow(Tensor<int> tensor)
{
var enumerator = tensor.GetEnumerator();

enumerator.Dispose();
enumerator.Dispose();
}
}
}
22 changes: 22 additions & 0 deletions src/libraries/System.Numerics.Tensors/tests/TensorTestsBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Linq;

namespace System.Numerics.Tensors.Tests
{
Expand Down Expand Up @@ -141,6 +142,27 @@ public static IEnumerable<object[]> GetTensorAndResultConstructor()
}
}

public static IEnumerable<object[]> GetConstructedTensors()
{
foreach (var ctor in GetSingleTensorConstructors().Select(x => (TensorConstructor)x[0]))
{
yield return new object[] { ctor.CreateFromArray<int>(Array.Empty<int>()) };
yield return new object[] { ctor.CreateFromArray<int>(new[] { 7 }) };
yield return new object[] { ctor.CreateFromArray<int>(new[] { 7, 14 }) };
yield return new object[] { ctor.CreateFromArray<int>(new[] { 7, 14, 21 }) };
yield return new object[]
{
ctor.CreateFromArray<int>(new[,]
{
{ 3, 6, 9 },
{ 5, 10, 15 },
{ 7, 14, 21 },
{ 11, 22, 33 }
})
};
}
}

public static NativeMemory<T> NativeMemoryFromArray<T>(T[] array)
{
return NativeMemoryFromArray<T>((Array)array);
Expand Down