Skip to content

Commit

Permalink
Add Charms.Constant.from_literal/5 (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper authored Dec 14, 2024
1 parent 123ab77 commit 37abfda
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 63 deletions.
29 changes: 29 additions & 0 deletions lib/charms/constant.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
defmodule Charms.Constant do
@moduledoc false
use Beaver
alias Beaver.MLIR.Dialect.{Arith, Index}

def from_literal(literal, %MLIR.Value{} = v, ctx, blk, loc) do
t = MLIR.CAPI.mlirValueGetType(v)
from_literal(literal, t, ctx, blk, loc)
end

def from_literal(literal, %MLIR.Type{} = t, ctx, blk, loc) do
mlir ctx: ctx, blk: blk do
cond do
MLIR.Type.integer?(t) ->
Arith.constant(value: Attribute.integer(t, literal), loc: loc) >>> t

MLIR.Type.float?(t) ->
Arith.constant(value: Attribute.float(t, literal), loc: loc) >>> t

MLIR.Type.index?(t) ->
Index.constant(value: Attribute.index(literal), loc: loc) >>> t

true ->
loc = Beaver.Deferred.create(loc, ctx)
raise CompileError, Charms.Diagnostic.meta_from_loc(loc) ++ [description: "Not a supported type for constant, #{to_string(t)}"]
end
end
end
end
27 changes: 15 additions & 12 deletions lib/charms/defm/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -298,18 +298,21 @@ defmodule Charms.Defm.Definition do
"""
def compile(definitions) when is_list(definitions) do
ctx = MLIR.Context.create()
{res, msg} = MLIR.Context.with_diagnostics(
ctx,
fn ->
try do
{:ok, do_compile(ctx, definitions)}
rescue
err ->
{:error, err}
end
end,
fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end
)

{res, msg} =
MLIR.Context.with_diagnostics(
ctx,
fn ->
try do
{:ok, do_compile(ctx, definitions)}
rescue
err ->
{:error, err}
end
end,
fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end
)

case {res, msg} do
{{:ok, {mlir, mods}}, nil} ->
MLIR.Context.destroy(ctx)
Expand Down
12 changes: 1 addition & 11 deletions lib/charms/defm/expander.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1210,17 +1210,7 @@ defmodule Charms.Defm.Expander do
value =
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
loc = MLIR.Location.from_env(env)

cond do
MLIR.CAPI.mlirTypeIsAInteger(type) |> Beaver.Native.to_term() ->
Arith.constant(value: Attribute.integer(type, value), loc: loc) >>> type

MLIR.CAPI.mlirTypeIsAFloat(type) |> Beaver.Native.to_term() ->
Arith.constant(value: Attribute.float(type, value), loc: loc) >>> type

true ->
raise_compile_error(env, "Unsupported type for const macro: #{to_string(type)}")
end
Charms.Constant.from_literal(value, type, state.mlir.ctx, state.mlir.blk, loc)
end

{value, state, env}
Expand Down
16 changes: 8 additions & 8 deletions lib/charms/diagnostic.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@ defmodule Charms.Diagnostic do
@moduledoc false
@doc false
alias Beaver.MLIR

def meta_from_loc(%MLIR.Location{} = loc) do
c = Regex.named_captures(~r/(?<file>.+):(?<line>\d+):(?<column>\d+)/, MLIR.to_string(loc))
[file: c["file"], line: c["line"] || 0]
end

def compile_error_message(%Beaver.MLIR.Diagnostic{} = d) do
loc = to_string(MLIR.location(d))
txt = to_string(d)

case txt do
"" ->
{:error, "No diagnostic message"}

note ->
c =
Regex.named_captures(
~r/(?<file>.+):(?<line>\d+):(?<column>\d+)/,
loc
)

{:ok, [file: c["file"], line: c["line"] || 0, description: note]}
{:ok, meta_from_loc(MLIR.location(d)) ++ [description: note]}
end
end

Expand Down
39 changes: 24 additions & 15 deletions lib/charms/jit.ex
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ defmodule Charms.JIT do

defp do_init(modules) when is_list(modules) do
ctx = MLIR.Context.create()

modules
|> Enum.map(fn
m when is_atom(m) ->
Expand All @@ -81,21 +82,29 @@ defmodule Charms.JIT do
raise ArgumentError, "Unexpected module type: #{inspect(other)}"
end)
|> then(fn op ->
{res, _} = MLIR.Context.with_diagnostics(
ctx,
fn ->
try do
{:ok, op |> merge_modules() |> jit_of_mod()}
rescue
err ->
{:error, err}
end
end,
fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end
)
case res do
{:ok, jit} -> jit
{:error, err} -> raise err
{res, msg} =
MLIR.Context.with_diagnostics(
ctx,
fn ->
try do
{:ok, op |> merge_modules() |> jit_of_mod()}
rescue
err ->
{:error, err, __STACKTRACE__}
end
end,
fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end
)

case {res, msg} do
{{:ok, jit}, nil} ->
jit

{{:error, _, st}, {:ok, d_msg}} ->
reraise CompileError, d_msg, st

{{:error, err, st}, _} ->
reraise err, st
end
end)
|> then(
Expand Down
23 changes: 7 additions & 16 deletions lib/charms/kernel.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,10 @@ defmodule Charms.Kernel do
use Charms.Intrinsic
alias Charms.Intrinsic.Opts
alias Beaver.MLIR.Dialect.Arith
@binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :*]
@binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :*, :/]
@unary_ops [:!]
@binary_macro_ops [:&&, :||]

defp constant_of_same_type(i, v, %Opts{ctx: ctx, blk: blk, loc: loc}) do
mlir ctx: ctx, blk: blk do
t = MLIR.CAPI.mlirValueGetType(v)

if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do
Arith.constant(value: Attribute.integer(t, i), loc: loc) >>> t
else
raise ArgumentError, "Not an integer type for constant, #{to_string(t)}"
end
end
end

@compare_ops [:!=, :==, :>, :>=, :<, :<=]
defp i_predicate(:!=), do: :ne
defp i_predicate(:==), do: :eq
Expand Down Expand Up @@ -51,6 +39,9 @@ defmodule Charms.Kernel do
:* ->
Arith.muli(operands, loc: loc) >>> type

:/ ->
Arith.divsi(operands, loc: loc) >>> type

_ ->
raise ArgumentError, "Unsupported operator: #{inspect(op)}"
end
Expand All @@ -59,15 +50,15 @@ defmodule Charms.Kernel do

for name <- @binary_ops ++ @binary_macro_ops do
defintrinsic unquote(name)(left, right) do
opts = %Opts{ctx: ctx, blk: blk, loc: loc} = __IR__
%Opts{ctx: ctx, blk: blk, loc: loc} = __IR__

{operands, type} =
case {left, right} do
{%MLIR.Value{} = v, i} when is_integer(i) ->
[v, constant_of_same_type(i, v, opts)]
[v, Charms.Constant.from_literal(i, v, ctx, blk, loc)]

{i, %MLIR.Value{} = v} when is_integer(i) ->
[constant_of_same_type(i, v, opts), v]
[Charms.Constant.from_literal(i, v, ctx, blk, loc), v]

{%MLIR.Value{}, %MLIR.Value{}} ->
if not MLIR.equal?(MLIR.Value.type(left), MLIR.Value.type(right)) do
Expand Down
2 changes: 1 addition & 1 deletion test/const_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ defmodule ConstTest do
end

assert_raise CompileError,
~r"test/const_test.exs:13: Unsupported type for const macro: tensor<\*xf64>",
~r"test/const_test.exs:13: Not a supported type for constant, tensor<\*xf64>",
f
end
end
2 changes: 2 additions & 0 deletions test/defm_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ defmodule AddTwoInt do
a = Pointer.load(i32(), ptr_a)
b = Pointer.load(i32(), ptr_b)
sum = value llvm.add(a, b) :: i32()
sum = sum / 1
sum = sum + 1 - 1
term = enif_make_int(env, sum)
func.return(term)
else
Expand Down

0 comments on commit 37abfda

Please sign in to comment.