Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Mar 9, 2024
1 parent e591a7d commit 2f571cc
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 41 deletions.
18 changes: 16 additions & 2 deletions src/walks.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
_map(f, x...) = map(f, x...)
function _map(f, x, ys...)
check_lenghts(x, ys...) || error("all arguments must have at least the same length of the firs one")
map(f, x, ys...)
end

function check_lenghts(x, ys...)
n = length(x)
return all(y -> length(y) >= n, ys)
end

_map(f, x::Dict, ys...) = Dict(k => f(v, (y[k] for y in ys)...) for (k, v) in x)

_values(x) = x
Expand All @@ -9,6 +18,7 @@ _keys(x::Tuple) = (keys(x)...,)
_keys(x::AbstractArray) = collect(keys(x))
_keys(x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(Ks)


"""
AbstractWalk
Expand Down Expand Up @@ -101,7 +111,11 @@ See [`fmapstructure`](@ref) for more information.
"""
struct StructuralWalk <: AbstractWalk end

(::StructuralWalk)(recurse, x) = _map(recurse, children(x))
function (::StructuralWalk)(recurse, x, ys...)
x_children = children(x)
ys_children = map(children, ys)
return _map(recurse, x_children, ys_children...)
end

struct StructuralWalkWithPath <: AbstractWalk end

Expand Down
79 changes: 40 additions & 39 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,37 @@ end
@test fmap(.*, m1, foo1, n1) == (x = [4*7, 2*5*8], y = 3*6*9)
end


@testset "fmapstructure(f, x, y)" begin
m1 = Foo([1,2], 3)
n1 = Foo([4,5], 6)
@test fmapstructure(+, m1, n1) == (x = [5, 7], y = 9)

# Mismatched trees should be an error
m2 = (x = [1,2], y = (a = [3,4], b = 5))
n2 = (x = [6,7], y = 8)
@test_throws Exception fmapstructure(firsttuple, m2, n2)

# The cache uses IDs from the first argument
shared = [1,2,3]
m3 = (x = shared, y = [4,5,6], z = shared)
n3 = (x = shared, y = shared, z = [7,8,9])
@test fmapstructure(+, m3, n3) == (x = [2, 4, 6], y = [5, 7, 9], z = [2, 4, 6])
z3 = fmapstructure(+, m3, n3)
@test z3.x === z3.z

# Pruning of duplicates:
@test fmapstructure(+, m3, n3; prune = nothing) == (x = [2,4,6], y = [5,7,9], z = nothing)

# More than two arguments:
z4 = fmapstructure(+, m3, n3, m3, n3)
@test z4 == fmapstructure(x -> 2x, z3)
@test z4.x === z4.z

foo1 = Foo([7,8], 9)
@test fmapstructure(.*, foo1, m1, n1) == (x = [4*7, 2*5*8], y = 3*6*9)
end

@testset "old test update.jl" begin
struct M{F,T,S}
σ::F
Expand Down Expand Up @@ -443,7 +474,7 @@ end
# Mismatched trees should be an error
m2 = (x = [1,2], y = (a = [3,4], b = 5))
n2 = (x = [6,7], y = 8)
@test_broken fmap_with_path((kp, x, y) -> x, m2, n2) isa Exception # ERROR: type Int64 has no field a
@test_throws Exception fmap_with_path((kp, x, y) -> x, m2, n2)

# The cache uses IDs from the first argument
shared = [1,2,3]
Expand Down Expand Up @@ -496,42 +527,12 @@ end
end


# @testset "fmap(f, x, y)" begin
# m1 = (x = [1,2], y = 3)
# n1 = (x = [4,5], y = 6)
# @test fmap(+, m1, n1) == (x = [5, 7], y = 9)

# # Reconstruction type comes from the first argument
# foo1 = Foo([7,8], 9)
# @test fmap(+, m1, foo1) == (x = [8, 10], y = 12)
# @test fmap(+, foo1, n1) isa Foo
# @test fmap(+, foo1, n1).x == [11, 13]

# # Mismatched trees should be an error
# m2 = (x = [1,2], y = (a = [3,4], b = 5))
# n2 = (x = [6,7], y = 8)
# @test_throws Exception fmap(first∘tuple, m2, n2) # ERROR: type Int64 has no field a
# @test_throws Exception fmap(first∘tuple, m2, n2)

# # The cache uses IDs from the first argument
# shared = [1,2,3]
# m3 = (x = shared, y = [4,5,6], z = shared)
# n3 = (x = shared, y = shared, z = [7,8,9])
# @test fmap(+, m3, n3) == (x = [2, 4, 6], y = [5, 7, 9], z = [2, 4, 6])
# z3 = fmap(+, m3, n3)
# @test z3.x === z3.z

# # Pruning of duplicates:
# @test fmap(+, m3, n3; prune = nothing) == (x = [2,4,6], y = [5,7,9], z = nothing)

# # More than two arguments:
# z4 = fmap(+, m3, n3, m3, n3)
# @test z4 == fmap(x -> 2x, z3)
# @test z4.x === z4.z

# @test fmap(+, foo1, m1, n1) isa Foo
# @static if VERSION >= v"1.6" # fails on Julia 1.0
# @test fmap(.*, m1, foo1, n1) == (x = [4*7, 2*5*8], y = 3*6*9)
# end
# end
@testset "fmapstructure_with_path(f, x, y)" begin
m1 = (x = [1,2], y = 3)
n1 = (x = [4,5], y = 6)
@test fmapstructure_with_path((kp, x, y) -> x + y, m1, n1) == (x = [5, 7], y = 9)

foo1 = Foo([7,8], 9)
@test fmapstructure_with_path((kp, x, y) -> x + y, foo1, m1) == (x = [8, 10], y = 12)
end
end

0 comments on commit 2f571cc

Please sign in to comment.