Skip to content

Commit

Permalink
Expand :|| and :! (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper authored Nov 12, 2024
1 parent 2773941 commit 97485cf
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 69 deletions.
66 changes: 52 additions & 14 deletions lib/charms/defm/expander.ex
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ defmodule Charms.Defm.Expander do
{l, state, env} = expand(l, state, env)
{init, state, env} = expand(init, state, env)
result_t = MLIR.Value.type(init)
{list_term_ptr, state} = uniq_mlir_var(state, l)
{list_term_ptr, state} = uniq_mlir_var(l, state)
tail_ptr = uniq_mlir_var()
head_ptr = uniq_mlir_var()

Expand Down Expand Up @@ -470,7 +470,7 @@ defmodule Charms.Defm.Expander do

defp expand_intrinsics(loc, module, fun, args, state, env) do
{args, state, env} = expand(args, state, env)
{params, state} = uniq_mlir_params(state, args)
{params, state} = uniq_mlir_params(args, state)

case v =
module.handle_intrinsic(fun, params, args,
Expand Down Expand Up @@ -692,13 +692,21 @@ defmodule Charms.Defm.Expander do
{left, state, env} = expand(left, state, env)
{right, state, env} = expand(right, state, env)
loc = MLIR.Location.from_env(env)
{params, state} = uniq_mlir_params(state, [left, right])
{params, state} = uniq_mlir_params([left, right], state)

{Charms.Prelude.handle_intrinsic(fun, params, [left, right],
ctx: state.mlir.ctx,
block: state.mlir.blk,
loc: loc
), state, env}
try do
{Charms.Prelude.handle_intrinsic(fun, params, [left, right],
ctx: state.mlir.ctx,
block: state.mlir.blk,
loc: loc
), state, env}
rescue
e ->
raise_compile_error(
env,
"Failed to expand prelude intrinsic #{fun}: #{Exception.message(e)}"
)
end
end

## =/2
Expand Down Expand Up @@ -1043,6 +1051,20 @@ defmodule Charms.Defm.Expander do

v =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
cond_type = MLIR.Value.type(condition)
bool_type = Type.i1(ctx: state.mlir.ctx)
# Ensure the condition is a i1, if not compare it to 0
condition =
if MLIR.equal?(cond_type, bool_type) do
condition
else
zero =
Arith.constant(value: Attribute.integer(cond_type, 0), loc: loc) >>> cond_type

Arith.cmpi(condition, zero, predicate: Arith.cmp_i_predicate(:sgt), loc: loc) >>>
Type.i1()
end

b =
block _true() do
ret_t =
Expand All @@ -1065,6 +1087,22 @@ defmodule Charms.Defm.Expander do
{v, state, env}
end

defp expand_macro(_meta, Kernel, :!, [value], _callback, state, env) do
{value, state, env} = expand(value, state, env)
type = MLIR.Value.type(value)
{value, state} = uniq_mlir_var(value, state)
{type, state} = uniq_mlir_var(type, state)

{not_value, state, env} =
quote do
one = const 1 :: unquote(type)
value arith.xori(unquote(value), one) :: unquote(type)
end
|> expand(state, env)

{List.last(not_value), state, env}
end

defp expand_macro(_meta, Charms.Defm, :while, [expr, [do: body]], _callback, state, env) do
v =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
Expand Down Expand Up @@ -1246,7 +1284,7 @@ defmodule Charms.Defm.Expander do
defp expand_remote(_meta, Kernel, fun, args, state, env) when fun in @prelude_intrinsics do
loc = MLIR.Location.from_env(env)
{args, state, env} = expand(args, state, env)
{params, state} = uniq_mlir_params(state, args)
{params, state} = uniq_mlir_params(args, state)

{Charms.Prelude.handle_intrinsic(fun, params, args,
ctx: state.mlir.ctx,
Expand Down Expand Up @@ -1282,7 +1320,7 @@ defmodule Charms.Defm.Expander do
if function_exported?(MLIR.Type, fun, 1) do
{apply(MLIR.Type, fun, [[ctx: state.mlir.ctx]]), state, env}
else
{params, state} = uniq_mlir_params(state, args)
{params, state} = uniq_mlir_params(args, state)

case i =
Charms.Prelude.handle_intrinsic(fun, params, args,
Expand Down Expand Up @@ -1369,7 +1407,7 @@ defmodule Charms.Defm.Expander do

defp beam_env_from_defm!(env, state) do
if e = state.mlir.enif_env do
uniq_mlir_var(state, e)
uniq_mlir_var(e, state)
else
raise_compile_error(env, "must be a defm with beam env as the first argument")
end
Expand All @@ -1380,14 +1418,14 @@ defmodule Charms.Defm.Expander do
Macro.var(:"#{@var_prefix}#{System.unique_integer([:positive])}", nil)
end

defp uniq_mlir_var(state, val) do
defp uniq_mlir_var(val, state) do
uniq_mlir_var() |> then(&{&1, put_mlir_var(state, &1, val)})
end

defp uniq_mlir_params(state, args) when is_list(args) do
defp uniq_mlir_params(args, state) when is_list(args) do
for param <- args, reduce: {[], state} do
{params, %{mlir: _} = state} ->
{param, %{mlir: _} = state} = uniq_mlir_var(state, param)
{param, %{mlir: _} = state} = uniq_mlir_var(param, state)
{params ++ [param], state}
end
end
Expand Down
18 changes: 15 additions & 3 deletions lib/charms/prelude.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@ defmodule Charms.Prelude do
use Charms.Intrinsic
alias Beaver.MLIR.Dialect.{Arith, Func}
@enif_functions Beaver.ENIF.functions()
@binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :&&, :*]
@binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :&&, :||, :*]

defp constant_of_same_type(i, v, opts) do
mlir ctx: opts[:ctx], block: opts[:block] do
t = MLIR.CAPI.mlirValueGetType(v)
Arith.constant(value: Attribute.integer(t, i)) >>> t

if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do
Arith.constant(value: Attribute.integer(t, i)) >>> t
else
raise ArgumentError, "Not an integer type for constant, #{to_string(t)}"
end
end
end

Expand All @@ -21,7 +26,11 @@ defmodule Charms.Prelude do
i

i when is_integer(i) ->
Arith.constant(value: Attribute.integer(t, i)) >>> t
if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do
Arith.constant(value: Attribute.integer(t, i)) >>> t
else
raise ArgumentError, "Not an integer type, #{to_string(t)}"
end
end
end
end
Expand Down Expand Up @@ -53,6 +62,9 @@ defmodule Charms.Prelude do
:&& ->
Arith.andi(operands) >>> type

:|| ->
Arith.ori(operands) >>> type

:* ->
Arith.muli(operands) >>> type
end
Expand Down
125 changes: 73 additions & 52 deletions test/defm_test.exs
Original file line number Diff line number Diff line change
@@ -1,3 +1,45 @@
defmodule AddTwoInt do
use Charms, init: false
alias Charms.{Pointer, Term}

defm add_or_error_with_cond_br(env, a, b, error) :: Term.t() do
ptr_a = Pointer.allocate(i32())
ptr_b = Pointer.allocate(i32())

arg_err =
block do
func.return(error)
end

cond_br enif_get_int(env, a, ptr_a) != 0 do
cond_br 0 != enif_get_int(env, b, ptr_b) do
a = Pointer.load(i32(), ptr_a)
b = Pointer.load(i32(), ptr_b)
sum = value llvm.add(a, b) :: i32()
term = enif_make_int(env, sum)
func.return(term)
else
^arg_err
end
else
^arg_err
end
end

defm add(env, a, b) :: Term.t() do
ptr_a = Pointer.allocate(i32())
ptr_b = Pointer.allocate(i32())

if !enif_get_int(env, a, ptr_a) || !enif_get_int(env, b, ptr_b) do
enif_make_badarg(env)
else
a = Pointer.load(i32(), ptr_a)
b = Pointer.load(i32(), ptr_b)
enif_make_int(env, a + b)
end
end
end

defmodule DefmTest do
use ExUnit.Case, async: true

Expand All @@ -6,28 +48,32 @@ defmodule DefmTest do
end

test "invalid return of absent alias" do
assert_raise CompileError, "test/defm_test.exs:13: invalid return type", fn ->
defmodule InvalidRet do
use Charms

defm my_function(env, arg1, arg2) :: Invalid.t() do
func.return(arg2)
end
end
end
assert_raise CompileError,
"test/defm_test.exs:#{__ENV__.line + 5}: invalid return type",
fn ->
defmodule InvalidRet do
use Charms

defm my_function(env, arg1, arg2) :: Invalid.t() do
func.return(arg2)
end
end
end
end

test "invalid arg of absent alias" do
assert_raise CompileError, "test/defm_test.exs:26: invalid argument type #2", fn ->
defmodule InvalidRet do
use Charms
alias Charms.Term

defm my_function(env, arg1 :: Pointer.t(), arg2) :: Term.t() do
func.return(arg2)
end
end
end
assert_raise CompileError,
"test/defm_test.exs:#{__ENV__.line + 6}: invalid argument type #2",
fn ->
defmodule InvalidRet do
use Charms
alias Charms.Term

defm my_function(env, arg1 :: Pointer.t(), arg2) :: Term.t() do
func.return(arg2)
end
end
end
end

test "only env defm is exported" do
Expand All @@ -39,41 +85,14 @@ defmodule DefmTest do
end

test "add two integers" do
defmodule AddTwoInt do
use Charms, init: false
alias Charms.{Pointer, Term}

defm add(env, a, b, error) :: Term.t() do
ptr_a = Pointer.allocate(i64())
ptr_b = Pointer.allocate(i64())

arg_err =
block do
func.return(error)
end

cond_br enif_get_int64(env, a, ptr_a) != 0 do
cond_br 0 != enif_get_int64(env, b, ptr_b) do
a = Pointer.load(i64(), ptr_a)
b = Pointer.load(i64(), ptr_b)
sum = value llvm.add(a, b) :: i64()
term = enif_make_int64(env, sum)
func.return(term)
else
^arg_err
end
else
^arg_err
end
end
end

assert {:ok, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int)
assert {:cached, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int)
engine = Charms.JIT.engine(:add_int)
assert String.starts_with?(AddTwoInt.__ir__(), "ML\xefR")
assert AddTwoInt.add(1, 2, :arg_err).(engine) == 3
assert AddTwoInt.add(1, "", :arg_err).(engine) == :arg_err
assert AddTwoInt.add(1, 2).(engine) == 3
assert_raise ArgumentError, fn -> AddTwoInt.add(1, "2").(engine) end
assert AddTwoInt.add_or_error_with_cond_br(1, 2, :arg_err).(engine) == 3
assert AddTwoInt.add_or_error_with_cond_br(1, "", :arg_err).(engine) == :arg_err
assert :ok = Charms.JIT.destroy(:add_int)
end

Expand Down Expand Up @@ -109,8 +128,10 @@ defmodule DefmTest do
end

test "undefined remote function" do
line = __ENV__.line

assert_raise CompileError,
"test/defm_test.exs:119: Failed to expand macro Elixir.DifferentCalls.something/1: test/defm_test.exs:119: function something not found in module DifferentCalls",
~r"Failed to expand macro Elixir.DifferentCalls.something/1.+function something not found in module DifferentCalls",
fn ->
defmodule Undefined do
use Charms
Expand All @@ -124,7 +145,7 @@ defmodule DefmTest do

test "wrong return type remote function" do
assert_raise CompileError,
"test/defm_test.exs:133: mismatch type in invocation: f32 vs. i64",
~r"mismatch type in invocation: f32 vs. i64",
fn ->
defmodule WrongReturnType do
use Charms
Expand Down

0 comments on commit 97485cf

Please sign in to comment.