diff --git a/src/sequential.jl b/src/sequential.jl index f72b41b..65579e4 100644 --- a/src/sequential.jl +++ b/src/sequential.jl @@ -68,6 +68,15 @@ Create a [`SequentialTransform`](@ref) transform with AbstractTrees.nodevalue(::SequentialTransform) = SequentialTransform AbstractTrees.children(s::SequentialTransform) = s.transforms +# iteration interface +Base.length(s::SequentialTransform) = length(s.transforms) +Base.iterate(s::SequentialTransform, args...) = iterate(s.transforms, args...) + +# indexing interface +Base.getindex(s::SequentialTransform, i) = getindex(s.transforms, i) +Base.firstindex(s::SequentialTransform) = firstindex(s.transforms) +Base.lastindex(s::SequentialTransform) = lastindex(s.transforms) + Base.show(io::IO, s::SequentialTransform) = print(io, join(s.transforms, " → ")) function Base.show(io::IO, ::MIME"text/plain", s::SequentialTransform) diff --git a/test/runtests.jl b/test/runtests.jl index 7fea00f..262acd4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,4 +26,21 @@ using Test T = TestTransform() → TestTransform() @test (T → Identity()) == T @test (Identity() → T) == T + + # sequential + T = TransformsBase.SequentialTransform([TestTransform(), Identity()]) + # iteration interface + @test length(T) == 2 + T1, state = iterate(T) + @test T1 == TestTransform() + T2, state = iterate(T, state) + @test T2 == Identity() + @test isnothing(iterate(T, state)) + # indexing interface + @test T[1] == TestTransform() + @test T[2] == Identity() + @test firstindex(T) == 1 + @test lastindex(T) == 2 + @test T[begin] == TestTransform() + @test T[end] == Identity() end