diff --git a/Project.toml b/Project.toml index 41b4eed..d982bda 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AxisSets" uuid = "a1a1544e-ba16-4f6d-8861-e833517b754e" authors = ["Invenia Technical Computing Corporation"] -version = "0.1.4" +version = "0.1.5" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" diff --git a/src/patterns.jl b/src/patterns.jl index aa99cc8..d2a6c61 100644 --- a/src/patterns.jl +++ b/src/patterns.jl @@ -88,8 +88,12 @@ function Base.in(item::Tuple, pattern::Pattern) item_val, item_st = item_iter pat_val, pat_st = pat_iter - # Iterate as normal if the pattern value matches or it's :_ - if item_val == pat_val || pat_val === :_ + # Iterate as normal if the pattern value matches, is a subtype or it's :_ + if ( + (item_val isa Type && pat_val isa Type && item_val <: pat_val) || + item_val == pat_val || + pat_val === :_ + ) pat_iter = iterate(pattern.segments, pat_st) item_iter = iterate(item, item_st) # Look ahead when we see a multi-value wildcard to see if the next value matches diff --git a/test/patterns.jl b/test/patterns.jl index 602e732..1840998 100644 --- a/test/patterns.jl +++ b/test/patterns.jl @@ -154,5 +154,54 @@ (t1, 1, "temperature", :time), (t1, 1, "temperature", :id), ] + + @testset "Subtype matching" begin + t1 = Float64 + t2 = Int + items = [ + (t1, 1, "prices", :time), + (t1, 1, "prices", :id), + (t1, 1, "prices", :lag), + (t1, 1, "load", :time), + (t1, 1, "load", :id), + (t1, 1, "temperature", :time), + (t1, 1, "temperature", :id), + (t1, 2, "prices", :time), + (t1, 2, "prices", :id), + (t2, 1, "prices", :time), + (t2, 1, "prices", :id), + (t2, 1, "prices", :lag), + (t2, 1, "load", :time), + (t2, 1, "load", :id), + (t2, 1, "temperature", :time), + (t2, 1, "temperature", :id), + (t2, 2, "prices", :time), + (t2, 2, "prices", :id), + ] + + pattern = Pattern(AbstractFloat, 1, :__) + @test filter(in(pattern), items) == [ + (t1, 1, "prices", :time), + (t1, 1, "prices", :id), + (t1, 1, "prices", :lag), + (t1, 1, "load", :time), + (t1, 1, "load", :id), + (t1, 1, "temperature", :time), + (t1, 1, "temperature", :id), + ] + + pattern = Pattern(Integer, 1, :__) + @test filter(in(pattern), items) == [ + (t2, 1, "prices", :time), + (t2, 1, "prices", :id), + (t2, 1, "prices", :lag), + (t2, 1, "load", :time), + (t2, 1, "load", :id), + (t2, 1, "temperature", :time), + (t2, 1, "temperature", :id), + ] + + @test filter(in(Pattern(Real, :__)), items) == items + end end end