Skip to content

Commit

Permalink
Use MacroTools
Browse files Browse the repository at this point in the history
  • Loading branch information
mgkurtz committed Jul 3, 2023
1 parent bd7be57 commit 51622b5
Showing 1 changed file with 39 additions and 35 deletions.
74 changes: 39 additions & 35 deletions src/misc/VarNames.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Base.Iterators
import MacroTools as MT

@assert VarName === Union{Symbol, AbstractString, Char}
const VarShape = Union{Missing, Int, Tuple{Vararg{Int}}}
Expand Down Expand Up @@ -94,6 +94,34 @@ _reshape(iter, ::Missing) = popfirst!(iter)
_reshape(iter, n::Int) = collect(Iterators.take(iter, n))
_reshape(iter, dims::Tuple{Vararg{Int}}) = reshape(collect(Iterators.take(iter, prod(dims))), dims)

function _varname_interface(e::Expr, @nospecialize s::Union{Expr, Symbol})
ex = MT.isexpr(e, :(=), :function) ? e : Expr(:(=), e, :())
d = MT.splitdef(ex)

callf = esc(d[:name])
f = esc(MT.postwalk(x -> MT.@capture(x, a_.b_) ? b : x, d[:name]))
wheres = esc.(d[:whereparams])

args = d[:args][begin:end-1]
splitargs = MacroTools.splitarg.(args)
args = esc.(args)
req(all(((_, _, slurp, default),) -> (slurp, default) === (false, nothing), splitargs),
"Default and slurp arguments currently not supported")
req(isempty(d[:kwargs]), "Keyword arguments currently not supported")
argnames = first.(splitargs)
req(all(!isnothing, argnames), "Nameless arguments currently not supported")
argnames = esc.(argnames)

s = esc(s)
argtypes = esc.(a[2] for a in splitargs)
argtypes = :(Tuple{$(argtypes...), $s} where {$(wheres...)})
base = f == callf ?
:(req(hasmethod($f, $argtypes), "base method of `$($f)` for $($argtypes) missing")) :
:($f($(args...), s::$s; kv...) where {$(wheres...)} = $callf($(argnames...), s; kv...))

return f, args, argnames, wheres, base
end

@doc raw"""
@varnames_interface [M.]f(args..., varnames)
Expand Down Expand Up @@ -172,10 +200,10 @@ julia> v = [1,4,1]; @f "variable dims" x[v...]; x
```
"""
macro varnames_interface(e::Expr, options...)
f, args, argnames, base = _varname_interface(e, :(Vector{Symbol}))
f, args, argnames, wheres, base = _varname_interface(e, :(Vector{Symbol}))
fancy_method = quote
$f($(args...), s::VarNames...; kv...) = $f($(args...), s; kv...)
function $f($(args...), s::Tuple{Vararg{VarNames}}; kv...)
$f($(args...), s::VarNames...; kv...) where {$(wheres...)} = $f($(argnames...), s; kv...)
function $f($(args...), s::Tuple{Vararg{VarNames}}; kv...) where {$(wheres...)}
X, gens = $f($(argnames...), variable_names(s...); kv...)
return X, reshape_to_varnames(gens, s...)...
end
Expand All @@ -191,7 +219,7 @@ macro varnames_interface(e::Expr, options...)
end

n isa Symbol || return :($base; $fancy_method)
fancy_n_method = :($f($(args...), $n::Int, s::VarName=:x; kv...) = $f($(argnames...), Symbol.(s, $en); kv...))
fancy_n_method = :($f($(args...), $n::Int, s::VarName=:x; kv...) where {$(wheres...)} = $f($(argnames...), Symbol.(s, $en); kv...))

opts[:macros] === :(:no) && return :($base; $fancy_method; $fancy_n_method)
ss, xs = opts[:macros] === :(:all) ? (:(s::Union{Expr, Symbol}...), :(_expr_pairs(s))) :
Expand All @@ -211,43 +239,20 @@ macro varnames_interface(e::Expr, options...)
return :($base; $fancy_method; $fancy_n_method; $fancy_macro)
end

function _varname_interface(e::Expr, @nospecialize s::Union{Expr, Symbol})
req(Base.isexpr(e, :call), "Argument `$e` must be a function call")
callf = esc(e.args[1])
f = esc(unqualified(e.args[1]))
# TODO cope with `where` clause, (and perhaps with optional arguments and keyvalues)
args = e.args[2:end-1]
argnames = esc.(argname.(args))
argtypes = Expr(:curly, :Tuple, argtype.(args)..., s)
args = esc.(args)
base = f == callf ?
:(req(hasmethod($f, $argtypes), "base method of `$($f)` for $($argtypes) missing")) :
:($f($(args...), s::$s; kv...) = $callf($(argnames...), s; kv...))
return f, args, argnames, base
end

function parse_options(kvs::Tuple{Vararg{Expr}}, default::Dict{Symbol}, valid::Dict{Symbol, <:Vector} = Dict{Symbol, Vector{Any}}())
result = Dict{Symbol, Any}(default)
for o in kvs
req(Base.isexpr(o, :(=), 2), "only key value options allowed")
k, v = o.args
MT.@capture(o, k_ = v_) || error("only key value options allowed")
req(k in keys(result), "invalid key value option key `$k`")
k in keys(valid) && req(v in valid[k], "invalid option `$v` to key `$k`")
result[k] = v
end
return result
end

unqualified(a::Symbol) = a
unqualified(e::Expr) = (req(Base.isexpr(e, :., 2), "Not a name: `$e`"); e.args[2].value) :: Symbol
argname(a::Symbol) = a
argname(e::Expr) = (req(Base.isexpr(e, :(::), 2), "Not a (possibly type asserted) Symbol: `$e`"); e.args[1]) :: Symbol
argtype(a::Symbol) = :Any
argtype(e::Expr) = (req(Base.isexpr(e, :(::)), "Not a type assertion or Symbol: `$e`"); e.args[end]) :: Union{Expr, Symbol}

_expr_pairs(a::Tuple{Vararg{Union{Expr, Symbol}}}) = _expr_pair.(a)
_expr_pairs((a,)::Tuple{Expr}) = Base.isexpr(a, :tuple) ? _expr_pair.(a.args) : (_expr_pair(a),) # for `@f args... (varnames...)` variant
_expr_pair(e::Expr) = (req(Base.isexpr(e, :ref), "variable name must be like `x` or `x[...]`, not `$e`"); e.args[1] => Expr(:tuple, esc.(e.args[2:end])...))
_expr_pairs(es::Tuple{Vararg{Union{Expr, Symbol}}}) = _expr_pair.(es)
_expr_pairs((e,)::Tuple{Expr}) = MT.@capture(e, (es__,)) ? _expr_pair.(es) : (_expr_pair(e),) # for `@f args... (varnames...,)` variant
_expr_pair(e::Expr) = MT.@capture(e, x_[a__]) ? x => :($(esc.(a)...),) : error("variable name must be like `x` or `x[...]`, not `$e`")
_expr_pair(s::Symbol) = s => missing

@doc raw"""
Expand Down Expand Up @@ -294,8 +299,8 @@ julia> x
```
"""
macro varname_interface(e::Expr)
f, args, argnames, base = _varname_interface(e, :Symbol)
fancy_method = :($f($(args...), s::Union{AbstractString, Char}; kv...) = $f($(argnames...), Symbol(s); kv...))
f, args, argnames, wheres, base = _varname_interface(e, :Symbol)
fancy_method = :($f($(args...), s::Union{AbstractString, Char}; kv...) where {$(wheres...)} = $f($(argnames...), Symbol(s); kv...))
fancy_macro = :(
macro $f($(argnames...), s::Symbol)
quote
Expand Down Expand Up @@ -330,4 +335,3 @@ polynomial_ring(R::Ring, s::Symbol; kv...) = invoke(polynomial_ring, Tuple{NCRin
polynomial_ring(R::Ring, s::Union{AbstractString, Char}; kv...) = polynomial_ring(R, Symbol(s); kv...)

# TODO: weights in `graded_polynomial_ring` and `power_series_ring`
# TODO: Cope with Julia 1.6 not having `isexpr`. Maybe use MacroTools.

0 comments on commit 51622b5

Please sign in to comment.