Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow reference loop between modules #39

Merged
merged 4 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 1 addition & 85 deletions bench/enif_merge_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,98 +3,14 @@ defmodule ENIFMergeSort do
use Charms
alias Charms.{Pointer, Term}

defm merge(arr :: Pointer.t(), l :: i32(), m :: i32(), r :: i32()) do
n1 = m - l + 1
n2 = r - m

left_temp = Pointer.allocate(Term.t(), n1)
right_temp = Pointer.allocate(Term.t(), n2)

for_loop {element, i} <- {Term.t(), Pointer.element_ptr(Term.t(), arr, l), n1} do
i = op index.casts(i) :: i32()
i = result_at(i, 0)
Pointer.store(element, Pointer.element_ptr(Term.t(), left_temp, i))
end

for_loop {element, j} <- {Term.t(), Pointer.element_ptr(Term.t(), arr, m + 1), n2} do
j = op index.casts(j) :: i32()
j = result_at(j, 0)
Pointer.store(element, Pointer.element_ptr(Term.t(), right_temp, j))
end

i_ptr = Pointer.allocate(i32())
j_ptr = Pointer.allocate(i32())
k_ptr = Pointer.allocate(i32())

zero = const 0 :: i32()
Pointer.store(zero, i_ptr)
Pointer.store(zero, j_ptr)
Pointer.store(l, k_ptr)

while_loop(Pointer.load(i32(), i_ptr) < n1 && Pointer.load(i32(), j_ptr) < n2) do
i = Pointer.load(i32(), i_ptr)
j = Pointer.load(i32(), j_ptr)
k = Pointer.load(i32(), k_ptr)

left_term = Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i))
right_term = Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j))

if enif_compare(left_term, right_term) <= 0 do
Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i)),
Pointer.element_ptr(Term.t(), arr, k)
)

Pointer.store(i + 1, i_ptr)
else
Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j)),
Pointer.element_ptr(Term.t(), arr, k)
)

Pointer.store(j + 1, j_ptr)
end

Pointer.store(k + 1, k_ptr)
end

while_loop(Pointer.load(i32(), i_ptr) < n1) do
i = Pointer.load(i32(), i_ptr)
k = Pointer.load(i32(), k_ptr)

Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i)),
Pointer.element_ptr(Term.t(), arr, k)
)

Pointer.store(i + 1, i_ptr)
Pointer.store(k + 1, k_ptr)
end

while_loop(Pointer.load(i32(), j_ptr) < n2) do
j = Pointer.load(i32(), j_ptr)
k = Pointer.load(i32(), k_ptr)

Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j)),
Pointer.element_ptr(Term.t(), arr, k)
)

Pointer.store(j + 1, j_ptr)
Pointer.store(k + 1, k_ptr)
end

func.return
end

defm do_sort(arr :: Pointer.t(), l :: i32(), r :: i32()) do
if l < r do
two = const 2 :: i32()
m = op arith.divsi(l + r, two) :: i32()
m = result_at(m, 0)
do_sort(arr, l, m)
do_sort(arr, m + 1, r)
merge(arr, l, m, r)
call SortUtil.merge(arr, l, m, r)
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
end

func.return
Expand Down
2 changes: 1 addition & 1 deletion bench/enif_tim_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ defmodule ENIFTimSort do
right = result_at(right, 0)

if mid < right do
call ENIFMergeSort.merge(arr, left, mid, right)
call SortUtil.merge(arr, left, mid, right)
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
end

Pointer.store(left + 2 * size, left_ptr)
Expand Down
88 changes: 88 additions & 0 deletions bench/sort_util.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
defmodule SortUtil do
use Charms
alias Charms.{Pointer, Term}

defm merge(arr :: Pointer.t(), l :: i32(), m :: i32(), r :: i32()) do
n1 = m - l + 1 - 1 + 1
n2 = r - m

left_temp = Pointer.allocate(Term.t(), n1)
right_temp = Pointer.allocate(Term.t(), n2)

for_loop {element, i} <- {Term.t(), Pointer.element_ptr(Term.t(), arr, l), n1} do
i = op index.casts(i) :: i32()
i = result_at(i, 0)
Pointer.store(element, Pointer.element_ptr(Term.t(), left_temp, i))
end

for_loop {element, j} <- {Term.t(), Pointer.element_ptr(Term.t(), arr, m + 1), n2} do
j = op index.casts(j) :: i32()
j = result_at(j, 0)
Pointer.store(element, Pointer.element_ptr(Term.t(), right_temp, j))
end

i_ptr = Pointer.allocate(i32())
j_ptr = Pointer.allocate(i32())
k_ptr = Pointer.allocate(i32())

zero = const 0 :: i32()
Pointer.store(zero, i_ptr)
Pointer.store(zero, j_ptr)
Pointer.store(l, k_ptr)

while_loop(Pointer.load(i32(), i_ptr) < n1 && Pointer.load(i32(), j_ptr) < n2) do
i = Pointer.load(i32(), i_ptr)
j = Pointer.load(i32(), j_ptr)
k = Pointer.load(i32(), k_ptr)

left_term = Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i))
right_term = Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j))

if enif_compare(left_term, right_term) <= 0 do
Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i)),
Pointer.element_ptr(Term.t(), arr, k)
)

Pointer.store(i + 1, i_ptr)
else
Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j)),
Pointer.element_ptr(Term.t(), arr, k)
)

Pointer.store(j + 1, j_ptr)
end

Pointer.store(k + 1, k_ptr)
end

while_loop(Pointer.load(i32(), i_ptr) < n1) do
i = Pointer.load(i32(), i_ptr)
k = Pointer.load(i32(), k_ptr)

Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i)),
Pointer.element_ptr(Term.t(), arr, k)
)

Pointer.store(i + 1, i_ptr)
Pointer.store(k + 1, k_ptr)
end

while_loop(Pointer.load(i32(), j_ptr) < n2) do
j = Pointer.load(i32(), j_ptr)
k = Pointer.load(i32(), k_ptr)

Pointer.store(
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j)),
Pointer.element_ptr(Term.t(), arr, k)
)

Pointer.store(j + 1, j_ptr)
Pointer.store(k + 1, k_ptr)
end

func.return
end
end
76 changes: 38 additions & 38 deletions lib/charms.ex
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,64 @@ defmodule Charms do

- We need a explicit `call` in function call because the `::` special form has a parser priority that is too low so a `call` macro is introduced to ensure proper scope.
- Being variadic, intrinsic must be called with the module name. `import` doesn't work with intrinsic functions while `alias` is supported.

## Glossary of modules

- `Charms`: the top level macros `defm` and `use Charms`
- `Charms.Defm`: the `defm` DSL syntax and special forms
- `Charms.Defm.Definition`: functions to define and compile `defm` functions to MLIR
- `Charms.Intrinsic`: the behavior used to define and compile intrinsic functions
"""

defmacro __using__(opts) do
quote do
import Charms
use Beaver
import Beaver.MLIR.Type

@doc false
def __use_ir__, do: nil
@before_compile Charms
Module.register_attribute(__MODULE__, :defm, accumulate: true)
Module.register_attribute(__MODULE__, :init_at_fun_call, persist: true)
@init_at_fun_call Keyword.get(unquote(opts), :init, true)
end
end

defmacro __before_compile__(_env) do
defmacro __before_compile__(env) do
defm_decls = Module.get_attribute(env.module, :defm) || []
{ir, referenced_modules} = defm_decls |> Enum.reverse() |> Charms.Defm.Definition.compile()

# create uses in Elixir, to disallow loop reference
r =
for r <- referenced_modules, r != env.module do
quote do
unquote(r).__use_ir__
end
end

quote do
{ir, referenced_modules} = @defm |> Enum.reverse() |> Charms.Defm.compile_definitions()
@ir ir
@referenced_modules referenced_modules
@ir unquote(ir)
@referenced_modules unquote(referenced_modules)
unquote_splicing(r)

@ir_hash [
:erlang.phash2(@ir)
| for r <- @referenced_modules, r != __MODULE__ do
r.__ir_digest__()
end
]
|> List.flatten()

@doc false
def __ir__ do
@ir
end

@doc false
def __ir_digest__ do
@ir_hash
end

@doc false
def referenced_modules do
@referenced_modules
Expand All @@ -61,38 +93,6 @@ defmodule Charms do
define a function that can be JIT compiled
"""
defmacro defm(call, body \\ []) do
{call, ret_types} = Charms.Defm.decompose_call_with_return_type(call)

call = Charms.Defm.normalize_call(call)
{name, args} = Macro.decompose_call(call)

{:ok, env} =
__CALLER__ |> Macro.Env.define_import([], Charms.Defm, warn: false, only: :macros)

[_enif_env | invoke_args] = args

invoke_args =
for {:"::", _, [a, _t]} <- invoke_args do
a
end

quote do
@defm unquote(Macro.escape({env, {call, ret_types, body}}))
def unquote(name)(unquote_splicing(invoke_args)) do
mfa = {unquote(env.module), unquote(name), unquote(invoke_args)}

cond do
@init_at_fun_call ->
{_, %Charms.JIT{engine: engine} = jit} = Charms.JIT.init(__MODULE__)
Charms.JIT.invoke(engine, mfa)

(engine = Charms.JIT.engine(__MODULE__)) != nil ->
Charms.JIT.invoke(engine, mfa)

true ->
&Charms.JIT.invoke(&1, mfa)
end
end
end
Charms.Defm.Definition.declare(__CALLER__, call, body)
end
end
Loading
Loading