Skip to content

Commit

Permalink
Consolidate extension API
Browse files Browse the repository at this point in the history
  • Loading branch information
arnodirlam committed Nov 13, 2024
1 parent 2c58cf0 commit 5070352
Show file tree
Hide file tree
Showing 18 changed files with 470 additions and 396 deletions.
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,6 @@ All syntax except `for` and `with` is supported in `defd` and (private) `defdp`
- All `Kernel` functions **except**:
- `apply/2`, `apply/3`
- `spawn/1`, `spawn_link/1`, `spawn_monitor/1`
- `tap/2`
- `then/2`

### Translatable to database queries

Expand Down
67 changes: 4 additions & 63 deletions lib/dx/date_time.ex
Original file line number Diff line number Diff line change
@@ -1,72 +1,13 @@
defmodule Dx.DateTime do
@moduledoc false

alias Dx.Defd.Ast
alias Dx.Defd.Compiler
use Dx.Defd.Ext

def rewrite(
{:&, meta, [{:/, [], [{{:., _meta2, [DateTime, fun_name]}, _meta3, []}, arity]}]},
state
) do
ast =
cond do
function_exported?(__MODULE__, fun_name, arity) ->
args = Macro.generate_arguments(arity, __MODULE__)
line = meta[:line] || state.line

quote line: line do
{:ok,
fn unquote_splicing(args) ->
unquote(__MODULE__).unquote(fun_name)(unquote_splicing(args))
end}
end

true ->
args = Macro.generate_arguments(arity, __MODULE__)
line = meta[:line] || state.line

quote line: line do
{:ok,
fn unquote_splicing(args) ->
{:ok, unquote(DateTime).unquote(fun_name)(unquote_splicing(args))}
end}
end
end

{ast, state}
end

def rewrite({{:., meta, [DateTime, fun_name]}, meta2, orig_args} = orig, state) do
arity = length(orig_args)

{args, state} = Enum.map_reduce(orig_args, state, &Compiler.normalize_load_unwrap/2)
{args, state} = Compiler.finalize_args(args, state)

ast =
cond do
Enum.all?(args, &Ast.ok?/1) ->
args = Enum.map(args, &Ast.unwrap_inner/1)

quote do
unquote({:ok, {{:., meta, [DateTime, fun_name]}, meta2, args}})
end

function_exported?(DateTime, fun_name, arity) ->
Compiler.compile_error!(meta, state, """
#{fun_name}/#{arity} is not supported by Dx yet.
Please check the issues in the repo, upvote, comment, or create an issue for it.
""")

true ->
{:ok, orig}
end

{ast, state}
@impl true
def __fun_info(_fun_name, _arity, _args) do
%{kernel?: true}
end

import Dx.Defd.Ext

defscope after?(left, right, generate_fallback) do
quote do: {:gt, unquote(left), unquote(right), unquote(generate_fallback.())}
end
Expand Down
4 changes: 2 additions & 2 deletions lib/dx/defd/ast.ex
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ defmodule Dx.Defd.Ast do
def finalize(ast, state) do
prewalk(ast, state, fn
var, state when is_var(var) ->
if var in state.finalized_vars do
if var_id(var) in state.finalized_vars do
{var, state}
else
{{:ok, var}, state} =
Expand Down Expand Up @@ -417,7 +417,7 @@ defmodule Dx.Defd.Ast do
# Helpers

@compile {:inline, with_state: 2}
defp with_state(ast, state), do: {ast, state}
def with_state(ast, state), do: {ast, state}

def closest_meta({_, meta, _}), do: meta
def closest_meta([elem | _rest]), do: closest_meta(elem)
Expand Down
4 changes: 3 additions & 1 deletion lib/dx/defd/ast/loader.ex
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ defmodule Dx.Defd.Ast.Loader do
{assigns, loaders} = Enum.split_with(loaders, &match?(%{ast: {:ok, _}}, &1))

assigns_ast =
Enum.map(assigns, fn %{ast: {:ok, right}} = loader ->
assigns
|> :lists.reverse()
|> Enum.map(fn %{ast: {:ok, right}} = loader ->
{:ok, {:=, [], [loader.data_var, right]}}
end)

Expand Down
2 changes: 1 addition & 1 deletion lib/dx/defd/block.ex
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ defmodule Dx.Defd.Block do
new_vars = MapSet.difference(new_state.args, state.args)

{rest, state} =
if new_vars == %{} do
if Enum.empty?(new_vars) do
normalize_block_body(rest, new_state)
else
case normalize_block_body(rest, new_state) do
Expand Down
18 changes: 9 additions & 9 deletions lib/dx/defd/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ defmodule Dx.Defd.Compiler do
{parts, state} =
Enum.map_reduce(parts, state, fn
{:"::", meta, [ast, {:binary, binary_meta, context}]}, state ->
{ast, state} = normalize_load_unwrap(ast, state)
{ast, state} = normalize_load_unwrap(ast, state) |> Ast.load_scopes()

{:"::", meta, [ast, {:binary, binary_meta, context}]}
|> with_state(state)
Expand Down Expand Up @@ -611,7 +611,7 @@ defmodule Dx.Defd.Compiler do

cond do
rewriter = @rewriters[module] ->
rewriter.rewrite(fun, state)
Dx.Defd.Rewriter.rewrite(fun, rewriter, state)

Util.is_defd?(module, fun_name, arity) ->
defd_name = Util.defd_name(fun_name)
Expand Down Expand Up @@ -740,7 +740,7 @@ defmodule Dx.Defd.Compiler do
end)

rewriter = @rewriters[module] ->
rewriter.rewrite(fun, state)
Dx.Defd.Rewriter.rewrite(fun, rewriter, state)

Util.is_defd?(module, fun_name, arity) ->
defd_name = Util.defd_name(fun_name)
Expand Down Expand Up @@ -793,34 +793,34 @@ defmodule Dx.Defd.Compiler do
""")
end

def maybe_load_scope({:ok, module}, state) when is_atom(module) do
def maybe_load_scope({:ok, module}, true, state) when is_atom(module) do
quote do
Dx.Scope.lookup(Dx.Scope.all(unquote(module)), unquote(state.eval_var))
end
|> Loader.add(state)
end

def maybe_load_scope({:ok, var}, state) when is_var(var) do
def maybe_load_scope({:ok, var}, _convert_atoms_to_scopes?, state) when is_var(var) do
quote do
Dx.Scope.maybe_lookup(unquote(var), unquote(state.eval_var))
end
|> Loader.add(state)
end

def maybe_load_scope({:ok, {:%{}, _meta, [{:__struct__, Dx.Scope} | _]} = ast}, state) do
def maybe_load_scope({:ok, {:%{}, _meta, [{:__struct__, Dx.Scope} | _]} = ast}, _, state) do
quote do
Dx.Scope.lookup(unquote(ast), unquote(state.eval_var))
end
|> Loader.add(state)
end

def maybe_load_scope({:ok, ast}, state) do
def maybe_load_scope({:ok, ast}, _, state) do
{{:ok, ast}, state}
end

# for undefined variables
def maybe_load_scope(other, state) do
{other, state}
def maybe_load_scope(other, _, state) do
{{:ok, other}, state}
end

def add_scope_loader_for({:ok, ast}, state) do
Expand Down
165 changes: 164 additions & 1 deletion lib/dx/defd/ext.ex
Original file line number Diff line number Diff line change
@@ -1,5 +1,168 @@
defmodule Dx.Defd.Ext do
@moduledoc false
@moduledoc """
Used to make existing libraries compatible with `Dx.Defd`.
## Usage
```elixir
defmodule MyExt do
use Dx.Defd.Ext
@impl true
def __fun_info(fun_name, arity, args) do
%{kernel?: true}
end
end
```
## Options
Return a map with the following keys:
- `args` - list or map of argument indexes mapping to `ArgInfo` structs
- `kernel?` - whether the function is a kernel function
- `scopable` - whether the function is scopable
- `convert_atoms_to_scopes?` - whether to convert atoms to scopes
- `warn_not_ok` - warning message to display when the function returns `{:error, _}`
- `warn_always` - warning message to display when the function returns `:ok`
"""

alias __MODULE__
alias Ext.ArgInfo
alias Ext.FunInfo

defmodule ArgInfo do
@moduledoc false

defstruct kernel?: false,
scopable: false,
fn: false

@type t() :: %__MODULE__{
kernel?: boolean(),
scopable: boolean(),
fn: boolean()
}

def new!(%ArgInfo{} = arg_info), do: arg_info
def new!(field) when is_atom(field), do: new!(%{field => true})
def new!(fields) when is_list(fields), do: fields |> Enum.map(&field!/1) |> new!()
def new!(fields), do: struct!(ArgInfo, fields)

defp field!(field) when is_atom(field), do: {field, true}
defp field!({field, value}) when is_atom(field), do: {field, value}
end

defmodule FunInfo do
@moduledoc false

defstruct module: nil,
fun_name: nil,
arity: nil,
args: [],
kernel?: false,
scopable: false,
convert_atoms_to_scopes?: false,
warn_not_ok: nil,
warn_always: nil

@type t() :: %__MODULE__{
module: atom(),
fun_name: atom(),
arity: non_neg_integer(),
args: %{non_neg_integer() => Dx.Defd.Ext.ArgInfo.t()} | list(Dx.Defd.Ext.ArgInfo.t()),
kernel?: boolean(),
scopable: boolean(),
convert_atoms_to_scopes?: boolean(),
warn_not_ok: binary() | nil,
warn_always: binary() | nil
}

@doc """
Creates a new `FunInfo` struct.
## Examples
iex> new!(%{args: [:scopable]}, %{arity: 2})
%FunInfo{arity: 2, args: [%ArgInfo{scopable: true}, %ArgInfo{}]}
iex> new!(%{args: %{0 => :scopable}}, %{arity: 2})
%FunInfo{arity: 2, args: [%ArgInfo{scopable: true}, %ArgInfo{}]}
"""
def new!(%FunInfo{} = fun_info, extra_fields) do
fun_info
|> struct!(extra_fields)
|> args!()
end

def new!(fields, extra_fields) do
FunInfo
|> struct!(fields)
|> struct!(extra_fields)
|> args!()
end

defp args!(%FunInfo{args: args} = fun_info) when is_map(args) do
Enum.each(args, fn {i, arg_info} ->
if i >= fun_info.arity do
raise ArgumentError,
"argument index must be less than the function's arity #{fun_info.arity}." <>
" Got #{i} => #{inspect(arg_info)}"
end
end)

args =
0..(fun_info.arity - 1)
|> Enum.map(fn i ->
case Map.fetch(args, i) do
{:ok, arg_info} -> ArgInfo.new!(arg_info)
:error -> %ArgInfo{}
end
end)

%{fun_info | args: args}
end

defp args!(%FunInfo{args: args} = fun_info) when is_list(args) do
if length(args) > fun_info.arity do
raise ArgumentError,
"number of arguments must be within the function's arity #{fun_info.arity}." <>
" Got #{length(args)} arguments for #{inspect(fun_info.module)}.#{fun_info.fun_name}/#{fun_info.arity}"
end

given_args =
args
|> Enum.map(&ArgInfo.new!/1)

non_given_args =
length(given_args)..(fun_info.arity - 1)
|> Enum.map(fn _i -> %ArgInfo{} end)

args = given_args ++ non_given_args

%{fun_info | args: args}
end

defp args!(_fun_info), do: raise(ArgumentError, "args must be a map or a list")
end

defmacro __using__(_opts) do
quote do
@behaviour Dx.Defd.Ext

alias Dx.Defd.Ext.ArgInfo
alias Dx.Defd.Ext.FunInfo

import Dx.Defd.Ext
end
end

@doc """
This callback is used to provide information about a function to `Dx.Defd`.
"""
@callback __fun_info(atom(), non_neg_integer(), list()) :: __MODULE__.FunInfo.t()

@optional_callbacks __fun_info: 3

alias Dx.Defd.Util

Expand Down
Loading

0 comments on commit 5070352

Please sign in to comment.