From 37abfda2fe0b1c9ad8ffeb7cd80891714e69fdfd Mon Sep 17 00:00:00 2001 From: Shenghang Tsai Date: Sat, 14 Dec 2024 10:57:13 +0800 Subject: [PATCH] Add `Charms.Constant.from_literal/5` (#56) --- lib/charms/constant.ex | 29 ++++++++++++++++++++++++++ lib/charms/defm/definition.ex | 27 +++++++++++++----------- lib/charms/defm/expander.ex | 12 +---------- lib/charms/diagnostic.ex | 16 +++++++------- lib/charms/jit.ex | 39 +++++++++++++++++++++-------------- lib/charms/kernel.ex | 23 +++++++-------------- test/const_test.exs | 2 +- test/defm_test.exs | 2 ++ 8 files changed, 87 insertions(+), 63 deletions(-) create mode 100644 lib/charms/constant.ex diff --git a/lib/charms/constant.ex b/lib/charms/constant.ex new file mode 100644 index 0000000..24d72ed --- /dev/null +++ b/lib/charms/constant.ex @@ -0,0 +1,29 @@ +defmodule Charms.Constant do + @moduledoc false + use Beaver + alias Beaver.MLIR.Dialect.{Arith, Index} + + def from_literal(literal, %MLIR.Value{} = v, ctx, blk, loc) do + t = MLIR.CAPI.mlirValueGetType(v) + from_literal(literal, t, ctx, blk, loc) + end + + def from_literal(literal, %MLIR.Type{} = t, ctx, blk, loc) do + mlir ctx: ctx, blk: blk do + cond do + MLIR.Type.integer?(t) -> + Arith.constant(value: Attribute.integer(t, literal), loc: loc) >>> t + + MLIR.Type.float?(t) -> + Arith.constant(value: Attribute.float(t, literal), loc: loc) >>> t + + MLIR.Type.index?(t) -> + Index.constant(value: Attribute.index(literal), loc: loc) >>> t + + true -> + loc = Beaver.Deferred.create(loc, ctx) + raise CompileError, Charms.Diagnostic.meta_from_loc(loc) ++ [description: "Not a supported type for constant, #{to_string(t)}"] + end + end + end +end diff --git a/lib/charms/defm/definition.ex b/lib/charms/defm/definition.ex index 97d2d81..09afbc5 100644 --- a/lib/charms/defm/definition.ex +++ b/lib/charms/defm/definition.ex @@ -298,18 +298,21 @@ defmodule Charms.Defm.Definition do """ def compile(definitions) when is_list(definitions) do ctx = MLIR.Context.create() - {res, msg} = MLIR.Context.with_diagnostics( - ctx, - fn -> - try do - {:ok, do_compile(ctx, definitions)} - rescue - err -> - {:error, err} - end - end, - fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end - ) + + {res, msg} = + MLIR.Context.with_diagnostics( + ctx, + fn -> + try do + {:ok, do_compile(ctx, definitions)} + rescue + err -> + {:error, err} + end + end, + fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end + ) + case {res, msg} do {{:ok, {mlir, mods}}, nil} -> MLIR.Context.destroy(ctx) diff --git a/lib/charms/defm/expander.ex b/lib/charms/defm/expander.ex index a2e19c2..cecb3c8 100644 --- a/lib/charms/defm/expander.ex +++ b/lib/charms/defm/expander.ex @@ -1210,17 +1210,7 @@ defmodule Charms.Defm.Expander do value = mlir ctx: state.mlir.ctx, blk: state.mlir.blk do loc = MLIR.Location.from_env(env) - - cond do - MLIR.CAPI.mlirTypeIsAInteger(type) |> Beaver.Native.to_term() -> - Arith.constant(value: Attribute.integer(type, value), loc: loc) >>> type - - MLIR.CAPI.mlirTypeIsAFloat(type) |> Beaver.Native.to_term() -> - Arith.constant(value: Attribute.float(type, value), loc: loc) >>> type - - true -> - raise_compile_error(env, "Unsupported type for const macro: #{to_string(type)}") - end + Charms.Constant.from_literal(value, type, state.mlir.ctx, state.mlir.blk, loc) end {value, state, env} diff --git a/lib/charms/diagnostic.ex b/lib/charms/diagnostic.ex index ba4bea1..69c4d6c 100644 --- a/lib/charms/diagnostic.ex +++ b/lib/charms/diagnostic.ex @@ -2,21 +2,21 @@ defmodule Charms.Diagnostic do @moduledoc false @doc false alias Beaver.MLIR + + def meta_from_loc(%MLIR.Location{} = loc) do + c = Regex.named_captures(~r/(?.+):(?\d+):(?\d+)/, MLIR.to_string(loc)) + [file: c["file"], line: c["line"] || 0] + end + def compile_error_message(%Beaver.MLIR.Diagnostic{} = d) do - loc = to_string(MLIR.location(d)) txt = to_string(d) + case txt do "" -> {:error, "No diagnostic message"} note -> - c = - Regex.named_captures( - ~r/(?.+):(?\d+):(?\d+)/, - loc - ) - - {:ok, [file: c["file"], line: c["line"] || 0, description: note]} + {:ok, meta_from_loc(MLIR.location(d)) ++ [description: note]} end end diff --git a/lib/charms/jit.ex b/lib/charms/jit.ex index ef31ace..b46a19d 100644 --- a/lib/charms/jit.ex +++ b/lib/charms/jit.ex @@ -66,6 +66,7 @@ defmodule Charms.JIT do defp do_init(modules) when is_list(modules) do ctx = MLIR.Context.create() + modules |> Enum.map(fn m when is_atom(m) -> @@ -81,21 +82,29 @@ defmodule Charms.JIT do raise ArgumentError, "Unexpected module type: #{inspect(other)}" end) |> then(fn op -> - {res, _} = MLIR.Context.with_diagnostics( - ctx, - fn -> - try do - {:ok, op |> merge_modules() |> jit_of_mod()} - rescue - err -> - {:error, err} - end - end, - fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end - ) - case res do - {:ok, jit} -> jit - {:error, err} -> raise err + {res, msg} = + MLIR.Context.with_diagnostics( + ctx, + fn -> + try do + {:ok, op |> merge_modules() |> jit_of_mod()} + rescue + err -> + {:error, err, __STACKTRACE__} + end + end, + fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end + ) + + case {res, msg} do + {{:ok, jit}, nil} -> + jit + + {{:error, _, st}, {:ok, d_msg}} -> + reraise CompileError, d_msg, st + + {{:error, err, st}, _} -> + reraise err, st end end) |> then( diff --git a/lib/charms/kernel.ex b/lib/charms/kernel.ex index 85afa6e..3253817 100644 --- a/lib/charms/kernel.ex +++ b/lib/charms/kernel.ex @@ -5,22 +5,10 @@ defmodule Charms.Kernel do use Charms.Intrinsic alias Charms.Intrinsic.Opts alias Beaver.MLIR.Dialect.Arith - @binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :*] + @binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :*, :/] @unary_ops [:!] @binary_macro_ops [:&&, :||] - defp constant_of_same_type(i, v, %Opts{ctx: ctx, blk: blk, loc: loc}) do - mlir ctx: ctx, blk: blk do - t = MLIR.CAPI.mlirValueGetType(v) - - if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do - Arith.constant(value: Attribute.integer(t, i), loc: loc) >>> t - else - raise ArgumentError, "Not an integer type for constant, #{to_string(t)}" - end - end - end - @compare_ops [:!=, :==, :>, :>=, :<, :<=] defp i_predicate(:!=), do: :ne defp i_predicate(:==), do: :eq @@ -51,6 +39,9 @@ defmodule Charms.Kernel do :* -> Arith.muli(operands, loc: loc) >>> type + :/ -> + Arith.divsi(operands, loc: loc) >>> type + _ -> raise ArgumentError, "Unsupported operator: #{inspect(op)}" end @@ -59,15 +50,15 @@ defmodule Charms.Kernel do for name <- @binary_ops ++ @binary_macro_ops do defintrinsic unquote(name)(left, right) do - opts = %Opts{ctx: ctx, blk: blk, loc: loc} = __IR__ + %Opts{ctx: ctx, blk: blk, loc: loc} = __IR__ {operands, type} = case {left, right} do {%MLIR.Value{} = v, i} when is_integer(i) -> - [v, constant_of_same_type(i, v, opts)] + [v, Charms.Constant.from_literal(i, v, ctx, blk, loc)] {i, %MLIR.Value{} = v} when is_integer(i) -> - [constant_of_same_type(i, v, opts), v] + [Charms.Constant.from_literal(i, v, ctx, blk, loc), v] {%MLIR.Value{}, %MLIR.Value{}} -> if not MLIR.equal?(MLIR.Value.type(left), MLIR.Value.type(right)) do diff --git a/test/const_test.exs b/test/const_test.exs index 488f993..23c38d3 100644 --- a/test/const_test.exs +++ b/test/const_test.exs @@ -16,7 +16,7 @@ defmodule ConstTest do end assert_raise CompileError, - ~r"test/const_test.exs:13: Unsupported type for const macro: tensor<\*xf64>", + ~r"test/const_test.exs:13: Not a supported type for constant, tensor<\*xf64>", f end end diff --git a/test/defm_test.exs b/test/defm_test.exs index 8281b32..f9a3c84 100644 --- a/test/defm_test.exs +++ b/test/defm_test.exs @@ -16,6 +16,8 @@ defmodule AddTwoInt do a = Pointer.load(i32(), ptr_a) b = Pointer.load(i32(), ptr_b) sum = value llvm.add(a, b) :: i32() + sum = sum / 1 + sum = sum + 1 - 1 term = enif_make_int(env, sum) func.return(term) else