Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function call return type infererence #41

Merged
merged 12 commits into from
Oct 20, 2024
6 changes: 2 additions & 4 deletions bench/enif_merge_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ defmodule ENIFMergeSort do
m = result_at(m, 0)
do_sort(arr, l, m)
do_sort(arr, m + 1, r)
call SortUtil.merge(arr, l, m, r)
SortUtil.merge(arr, l, m, r)
end

func.return
end

@err %ArgumentError{message: "list expected"}
Expand All @@ -25,7 +23,7 @@ defmodule ENIFMergeSort do
Pointer.store(list, movable_list_ptr)
len = Pointer.load(i32(), len_ptr)
arr = Pointer.allocate(Term.t(), len)
call ENIFTimSort.copy_terms(env, movable_list_ptr, arr)
SortUtil.copy_terms(env, movable_list_ptr, arr)
zero = const 0 :: i32()
do_sort(arr, zero, len - 1)
enif_make_list_from_array(env, arr, len)
Expand Down
28 changes: 3 additions & 25 deletions bench/enif_quick_sort.ex
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defmodule ENIFQuickSort do
@moduledoc false
use Charms
alias Charms.{Pointer, Term, Env}
alias Charms.{Pointer, Term}

defm swap(a :: Pointer.t(), b :: Pointer.t()) do
tmp = Pointer.allocate(Term.t())
Expand Down Expand Up @@ -36,34 +36,12 @@ defmodule ENIFQuickSort do

defm do_sort(arr :: Pointer.t(), low :: i32(), high :: i32()) do
if low < high do
pi = call partition(arr, low, high) :: i32()
pi = partition(arr, low, high)
do_sort(arr, low, pi - 1)
do_sort(arr, pi + 1, high)
end
end

defm copy_terms(env :: Env.t(), movable_list_ptr :: Pointer.t(), arr :: Pointer.t()) do
head = Pointer.allocate(Term.t())
zero = const 0 :: i32()
i_ptr = Pointer.allocate(i32())
Pointer.store(zero, i_ptr)

while_loop(
enif_get_list_cell(
env,
Pointer.load(Term.t(), movable_list_ptr),
head,
movable_list_ptr
) > 0
) do
head_val = Pointer.load(Term.t(), head)
i = Pointer.load(i32(), i_ptr)
ith_term_ptr = Pointer.element_ptr(Term.t(), arr, i)
Pointer.store(head_val, ith_term_ptr)
Pointer.store(i + 1, i_ptr)
end
end

@err %ArgumentError{message: "list expected"}
defm sort(env, list) :: Term.t() do
len_ptr = Pointer.allocate(i32())
Expand All @@ -73,7 +51,7 @@ defmodule ENIFQuickSort do
Pointer.store(list, movable_list_ptr)
len = Pointer.load(i32(), len_ptr)
arr = Pointer.allocate(Term.t(), len)
copy_terms(env, movable_list_ptr, arr)
SortUtil.copy_terms(env, movable_list_ptr, arr)
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
zero = const 0 :: i32()
do_sort(arr, zero, len - 1)
enif_make_list_from_array(env, arr, len)
Expand Down
36 changes: 7 additions & 29 deletions bench/enif_tim_sort.ex
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defmodule ENIFTimSort do
@moduledoc false
use Charms
alias Charms.{Pointer, Term, Env}
alias Charms.{Pointer, Term}

defm insertion_sort(arr :: Pointer.t(), left :: i32(), right :: i32()) do
start_i = left + 1
Expand All @@ -14,7 +14,7 @@ defmodule ENIFTimSort do
j_ptr = Pointer.allocate(i32())
Pointer.store(i - 1, j_ptr)

while_loop(
while(
Pointer.load(i32(), j_ptr) >= left &&
Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), arr, Pointer.load(i32(), j_ptr))) >
temp
Expand All @@ -40,7 +40,7 @@ defmodule ENIFTimSort do
zero = const 0 :: i32()
Pointer.store(zero, i_ptr)

while_loop(Pointer.load(i32(), i_ptr) < n) do
while Pointer.load(i32(), i_ptr) < n do
i = Pointer.load(i32(), i_ptr)
min = value arith.minsi(i + run - 1, n - 1) :: i32()
insertion_sort(arr, i, min)
Expand All @@ -50,20 +50,20 @@ defmodule ENIFTimSort do
size_ptr = Pointer.allocate(i32())
Pointer.store(run, size_ptr)

while_loop(Pointer.load(i32(), size_ptr) < n) do
while Pointer.load(i32(), size_ptr) < n do
size = Pointer.load(i32(), size_ptr)

left_ptr = Pointer.allocate(i32())
Pointer.store(zero, left_ptr)

while_loop(Pointer.load(i32(), left_ptr) < n) do
while Pointer.load(i32(), left_ptr) < n do
left = Pointer.load(i32(), left_ptr)
mid = left + size - 1
right = op arith.minsi(left + 2 * size - 1, n - 1) :: i32()
right = result_at(right, 0)

if mid < right do
call SortUtil.merge(arr, left, mid, right)
SortUtil.merge(arr, left, mid, right)
end

Pointer.store(left + 2 * size, left_ptr)
Expand All @@ -73,28 +73,6 @@ defmodule ENIFTimSort do
end
end

defm copy_terms(env :: Env.t(), movable_list_ptr :: Pointer.t(), arr :: Pointer.t()) do
head = Pointer.allocate(Term.t())
zero = const 0 :: i32()
i_ptr = Pointer.allocate(i32())
Pointer.store(zero, i_ptr)

while_loop(
enif_get_list_cell(
env,
Pointer.load(Term.t(), movable_list_ptr),
head,
movable_list_ptr
) > 0
) do
head_val = Pointer.load(Term.t(), head)
i = Pointer.load(i32(), i_ptr)
ith_term_ptr = Pointer.element_ptr(Term.t(), arr, i)
Pointer.store(head_val, ith_term_ptr)
Pointer.store(i + 1, i_ptr)
end
end

@err %ArgumentError{message: "list expected"}
defm sort(env, list) :: Term.t() do
len_ptr = Pointer.allocate(i32())
Expand All @@ -104,7 +82,7 @@ defmodule ENIFTimSort do
Pointer.store(list, movable_list_ptr)
len = Pointer.load(i32(), len_ptr)
arr = Pointer.allocate(Term.t(), len)
copy_terms(env, movable_list_ptr, arr)
SortUtil.copy_terms(env, movable_list_ptr, arr)
tim_sort(arr, len)
enif_make_list_from_array(env, arr, len)
else
Expand Down
32 changes: 26 additions & 6 deletions bench/sort_util.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,30 @@ defmodule SortUtil do
use Charms
alias Charms.{Pointer, Term}

defm copy_terms(env, movable_list_ptr :: Pointer.t(), arr :: Pointer.t()) do
head = Pointer.allocate(Term.t())
zero = const 0 :: i32()
i_ptr = Pointer.allocate(i32())
Pointer.store(zero, i_ptr)

while(
enif_get_list_cell(
env,
Pointer.load(Term.t(), movable_list_ptr),
head,
movable_list_ptr
) > 0
) do
head_val = Pointer.load(Term.t(), head)
i = Pointer.load(i32(), i_ptr)
ith_term_ptr = Pointer.element_ptr(Term.t(), arr, i)
Pointer.store(head_val, ith_term_ptr)
Pointer.store(i + 1, i_ptr)
end
end

defm merge(arr :: Pointer.t(), l :: i32(), m :: i32(), r :: i32()) do
n1 = m - l + 1 - 1 + 1
n1 = m - l + 1
n2 = r - m

left_temp = Pointer.allocate(Term.t(), n1)
Expand All @@ -30,7 +52,7 @@ defmodule SortUtil do
Pointer.store(zero, j_ptr)
Pointer.store(l, k_ptr)

while_loop(Pointer.load(i32(), i_ptr) < n1 && Pointer.load(i32(), j_ptr) < n2) do
while Pointer.load(i32(), i_ptr) < n1 && Pointer.load(i32(), j_ptr) < n2 do
i = Pointer.load(i32(), i_ptr)
j = Pointer.load(i32(), j_ptr)
k = Pointer.load(i32(), k_ptr)
Expand All @@ -57,7 +79,7 @@ defmodule SortUtil do
Pointer.store(k + 1, k_ptr)
end

while_loop(Pointer.load(i32(), i_ptr) < n1) do
while Pointer.load(i32(), i_ptr) < n1 do
i = Pointer.load(i32(), i_ptr)
k = Pointer.load(i32(), k_ptr)

Expand All @@ -70,7 +92,7 @@ defmodule SortUtil do
Pointer.store(k + 1, k_ptr)
end

while_loop(Pointer.load(i32(), j_ptr) < n2) do
while Pointer.load(i32(), j_ptr) < n2 do
j = Pointer.load(i32(), j_ptr)
k = Pointer.load(i32(), k_ptr)

Expand All @@ -82,7 +104,5 @@ defmodule SortUtil do
Pointer.store(j + 1, j_ptr)
Pointer.store(k + 1, k_ptr)
end

func.return
end
end
8 changes: 4 additions & 4 deletions bench/vec_add_int_list.ex
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ defmodule AddTwoIntVec do
end

defm add(env, a, b, error) :: Term.t() do
v1 = call load_list(env, a) :: SIMD.t(i32(), 8)
v2 = call load_list(env, b) :: SIMD.t(i32(), 8)
v1 = load_list(env, a)
v2 = load_list(env, b)
v = arith.addi(v1, v2)
start = const 0 :: i32()

Expand All @@ -40,8 +40,8 @@ defmodule AddTwoIntVec do
end

defm dummy_load_no_make(env, a, b, error) :: Term.t() do
v1 = call load_list(env, a) :: SIMD.t(i32(), 8)
v2 = call load_list(env, b) :: SIMD.t(i32(), 8)
v1 = load_list(env, a)
v2 = load_list(env, b)
func.return(a)
end

Expand Down
Loading
Loading