Skip to content

Commit

Permalink
Update for beaver 0.4.0 (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper authored Dec 5, 2024
1 parent 456ac91 commit 123ab77
Show file tree
Hide file tree
Showing 17 changed files with 154 additions and 168 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/elixir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Set up Elixir
uses: erlef/setup-beam@v1
with:
elixir-version: "1.17-dev" # [Required] Define the Elixir version
elixir-version: "1.17" # [Required] Define the Elixir version
otp-version: "26.0" # [Required] Define the Erlang/OTP version
- name: Restore dependencies cache
uses: actions/cache@v3
Expand Down
2 changes: 1 addition & 1 deletion guides/programming-with-charms.livemd
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

```elixir
Mix.install([
{:charms, "~> 0.1.1"}
{:charms, "~> 0.1.2"}
])
```

Expand Down
2 changes: 1 addition & 1 deletion lib/charms/debug.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ defmodule Charms.Debug do
MLIR.dump!(op)

_ ->
MLIR.Transforms.print_ir(op)
MLIR.Transform.print_ir(op)
end
else
op
Expand Down
59 changes: 37 additions & 22 deletions lib/charms/defm/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,7 @@ defmodule Charms.Defm.Definition do
end

if msg = Beaver.Walker.attributes(op)["msg"] do
MLIR.Attribute.unwrap(msg)
|> MLIR.StringRef.to_string()
MLIR.Attribute.unwrap(msg) |> MLIR.to_string()
else
"Poison operation detected in the IR. #{to_string(op)}"
end
Expand All @@ -180,12 +179,12 @@ defmodule Charms.Defm.Definition do
|> MLIR.CAPI.mlirFunctionTypeGetNumResults()
|> Beaver.Native.to_term() do
0 ->
mlir ctx: MLIR.CAPI.mlirOperationGetContext(func), block: b do
mlir ctx: MLIR.CAPI.mlirOperationGetContext(func), blk: b do
Func.return(loc: MLIR.Operation.location(func)) >>> []
end

1 ->
mlir ctx: MLIR.CAPI.mlirOperationGetContext(last_op), block: b do
mlir ctx: MLIR.CAPI.mlirOperationGetContext(last_op), blk: b do
results = Beaver.Walker.results(last_op) |> Enum.to_list()
Func.return(results, loc: MLIR.Operation.location(last_op)) >>> []
end
Expand Down Expand Up @@ -225,10 +224,13 @@ defmodule Charms.Defm.Definition do
|> then(fn {_, acc} -> MapSet.to_list(acc) end)
end

def do_compile(ctx, definitions, diagnostic_server) do
m = MLIR.Module.create(ctx, "")
def do_compile(ctx, definitions) do
# this function might be called at compile time, so we need to ensure the application is started
:ok = Application.ensure_started(:kinda)
:ok = Application.ensure_started(:beaver)
m = MLIR.Module.create("", ctx: ctx)

mlir ctx: ctx, block: MLIR.Module.body(m) do
mlir ctx: ctx, blk: MLIR.Module.body(m) do
mlir_expander = %Charms.Defm.Expander{
ctx: ctx,
blk: Beaver.Env.block(),
Expand Down Expand Up @@ -266,20 +268,20 @@ defmodule Charms.Defm.Definition do

m
|> Charms.Debug.print_ir_pass()
|> MLIR.Pass.Composer.nested(
|> Beaver.Composer.nested(
"func.func",
{"append_missing_return", "func.func", &append_missing_return/1}
)
|> MLIR.Pass.Composer.nested("func.func", Charms.Defm.Pass.CreateAbsentFunc)
|> MLIR.Pass.Composer.append({"check-poison", "builtin.module", &check_poison!/1})
|> MLIR.Transforms.canonicalize()
|> Beaver.Composer.nested("func.func", Charms.Defm.Pass.CreateAbsentFunc)
|> Beaver.Composer.append({"check-poison", "builtin.module", &check_poison!/1})
|> MLIR.Transform.canonicalize()
|> then(fn op ->
case MLIR.Pass.Composer.run(op, print: Charms.Debug.step_print?()) do
case Beaver.Composer.run(op, print: Charms.Debug.step_print?()) do
{:ok, op} ->
op

{:error, msg} ->
raise_compile_error(__ENV__, diagnostic_server, msg)
raise_compile_error(__ENV__, msg)
end
end)
|> then(&{MLIR.to_string(&1, bytecode: true), referenced_modules(&1)})
Expand All @@ -296,15 +298,28 @@ defmodule Charms.Defm.Definition do
"""
def compile(definitions) when is_list(definitions) do
ctx = MLIR.Context.create()
{:ok, diagnostic_server} = GenServer.start(Beaver.Diagnostic.Server, [])
diagnostic_handler_id = Beaver.Diagnostic.attach(ctx, diagnostic_server)

try do
do_compile(ctx, definitions, diagnostic_server)
after
Beaver.Diagnostic.detach(ctx, diagnostic_handler_id)
MLIR.Context.destroy(ctx)
:ok = GenServer.stop(diagnostic_server)
{res, msg} = MLIR.Context.with_diagnostics(
ctx,
fn ->
try do
{:ok, do_compile(ctx, definitions)}
rescue
err ->
{:error, err}
end
end,
fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end
)
case {res, msg} do
{{:ok, {mlir, mods}}, nil} ->
MLIR.Context.destroy(ctx)
{mlir, mods}

{_, {:ok, d_msg}} ->
raise CompileError, d_msg

{{:error, err}, _} ->
raise err
end
end

Expand Down
53 changes: 26 additions & 27 deletions lib/charms/defm/expander.ex
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ defmodule Charms.Defm.Expander do
"""
use Beaver
alias MLIR.Attribute
alias MLIR.Dialect.{Func, CF, SCF, MemRef, Index, Arith, Ub, LLVM}
alias MLIR.Dialect.{Func, CF, SCF, MemRef, Index, Arith, UB, LLVM}
require Func
import Charms.Diagnostic, only: :macros
# Define the environment we will use for expansion.
Expand Down Expand Up @@ -67,7 +67,6 @@ defmodule Charms.Defm.Expander do
"""
def expand(ast, file) do
ctx = MLIR.Context.create()
Beaver.Diagnostic.attach(ctx)
available_ops = MapSet.new(MLIR.Dialect.Registry.ops(:all, ctx: ctx))

mlir = %__MODULE__{
Expand Down Expand Up @@ -102,12 +101,12 @@ defmodule Charms.Defm.Expander do

defp create_call(mod, name, args, types, state, env) do
op =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
%Beaver.SSA{
op: "func.call",
arguments: args ++ [callee: Attribute.flat_symbol_ref(mangling(mod, name))],
ctx: Beaver.Env.context(),
block: Beaver.Env.block(),
blk: Beaver.Env.block(),
loc: MLIR.Location.from_env(env)
}
|> Beaver.SSA.put_results(types)
Expand All @@ -118,8 +117,8 @@ defmodule Charms.Defm.Expander do
end

defp create_poison(msg, state, env) do
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
Ub.poison(msg: MLIR.Attribute.string(msg), loc: MLIR.Location.from_env(env)) >>>
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
UB.poison(msg: MLIR.Attribute.string(msg), loc: MLIR.Location.from_env(env)) >>>
~t{none}
end
|> then(&{&1, state, env})
Expand All @@ -130,7 +129,7 @@ defmodule Charms.Defm.Expander do
update_in(
state.mlir.dependence_modules,
&Map.put_new_lazy(&1, module, fn ->
MLIR.Module.create(state.mlir.ctx, module.__ir__()) |> MLIR.Operation.from_module()
MLIR.Module.create(module.__ir__(), ctx: state.mlir.ctx) |> MLIR.Operation.from_module()
end)
)
|> then(&{&1.mlir.dependence_modules[module], &1})
Expand Down Expand Up @@ -175,7 +174,7 @@ defmodule Charms.Defm.Expander do
MLIR.StringRef.create(mangling(mod, name))
)

if MLIR.is_null(sym) do
if MLIR.null?(sym) do
raise_compile_error(
env,
"function #{name} not found in module #{inspect(mod)}"
Expand Down Expand Up @@ -274,7 +273,7 @@ defmodule Charms.Defm.Expander do

defp expand_std(Enum, :reduce, args, state, env) do
while =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
[l, init, f] = args
{l, state, env} = expand(l, state, env)
{init, state, env} = expand(init, state, env)
Expand Down Expand Up @@ -343,7 +342,7 @@ defmodule Charms.Defm.Expander do
defp expand_std(String, :length, args, state, env) do
{string, state, env} = expand(args, state, env)

mlir ctx: state.mlir.ctx, block: state.mlir.blk do
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
zero = Index.constant(value: Attribute.index(0)) >>> Type.index()
len = MemRef.dim(string, zero) >>> :infer
end
Expand Down Expand Up @@ -371,7 +370,7 @@ defmodule Charms.Defm.Expander do

state =
with [head_arg_type | _] <- arg_types,
MLIR.Type.equal?(head_arg_type, Beaver.ENIF.Type.env(ctx: state.mlir.ctx)),
MLIR.equal?(head_arg_type, Beaver.ENIF.Type.env(ctx: state.mlir.ctx)),
[{:env, _, nil} | _] <- args do
a = MLIR.Block.get_arg!(Beaver.Env.block(), 0)
put_in(state.mlir.enif_env, a)
Expand Down Expand Up @@ -430,7 +429,7 @@ defmodule Charms.Defm.Expander do
op: op,
arguments: args,
ctx: state.mlir.ctx,
block: state.mlir.blk,
blk: state.mlir.blk,
loc: MLIR.Location.from_env(env),
results: if(has_implemented_inference(op, state.mlir.ctx), do: [:infer], else: [])
}
Expand Down Expand Up @@ -473,7 +472,7 @@ defmodule Charms.Defm.Expander do
args,
%Charms.Intrinsic.Opts{
ctx: state.mlir.ctx,
block: state.mlir.blk,
blk: state.mlir.blk,
loc: loc
}
])
Expand Down Expand Up @@ -754,10 +753,10 @@ defmodule Charms.Defm.Expander do

defp expand({:^, _meta, [arg]}, state, %{context: context} = env) do
{b, state, env} = expand(arg, state, %{env | context: nil})
match?(%MLIR.Block{}, b) || raise Beaver.EnvNotFoundError, MLIR.Block
match?(%MLIR.Block{}, b) || raise_compile_error(env, "Expected a block, got: #{inspect(b)}")

br =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
CF.br({b, []}) >>> []
end

Expand Down Expand Up @@ -845,14 +844,14 @@ defmodule Charms.Defm.Expander do
found = MLIR.CAPI.mlirSymbolTableLookup(s_table, MLIR.StringRef.create(sym_name))
loc = MLIR.Location.from_env(env)

mlir ctx: state.mlir.ctx, block: MLIR.Module.body(state.mlir.mod) do
if MLIR.is_null(found) do
mlir ctx: state.mlir.ctx, blk: MLIR.Module.body(state.mlir.mod) do
if MLIR.null?(found) do
MemRef.global(ast, sym_name: Attribute.string(sym_name), loc: loc) >>> :infer
else
found
end
|> then(
&mlir block: state.mlir.blk do
&mlir blk: state.mlir.blk do
name = Attribute.flat_symbol_ref(Attribute.unwrap(&1[:sym_name]))
MemRef.get_global(name: name, loc: loc) >>> Attribute.unwrap(&1[:type])
end
Expand Down Expand Up @@ -880,15 +879,15 @@ defmodule Charms.Defm.Expander do

# Expands a nil clause body in an if statement, yielding no value.
defp expand_if_clause_body(nil, state, _env) do
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
SCF.yield() >>> []
[]
end
end

# Expands a non-nil clause body in an if statement, yielding the last evaluated value.
defp expand_if_clause_body(clause_body, state, env) do
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
{ret, _, _} = expand(clause_body, state, env)

case ret do
Expand Down Expand Up @@ -961,7 +960,7 @@ defmodule Charms.Defm.Expander do
parent_block = state.mlir.blk

f =
mlir ctx: state.mlir.ctx, block: parent_block do
mlir ctx: state.mlir.ctx, blk: parent_block do
{ret_types, state, env} = ret_types |> expand(state, env)
{arg_types, state, env} = arg_types |> expand(state, env)

Expand Down Expand Up @@ -1012,7 +1011,7 @@ defmodule Charms.Defm.Expander do
{condition, state, env} = expand(condition, state, env)

v =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
true_body =
block do
expand(true_body, put_in(state.mlir.blk, Beaver.Env.block()), env)
Expand Down Expand Up @@ -1044,7 +1043,7 @@ defmodule Charms.Defm.Expander do
loc = MLIR.Location.from_env(env)

v =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
cond_type = MLIR.Value.type(condition)
bool_type = Type.i1(ctx: state.mlir.ctx)
# Ensure the condition is a i1, if not compare it to 0
Expand Down Expand Up @@ -1085,7 +1084,7 @@ defmodule Charms.Defm.Expander do
loc = MLIR.Location.from_env(env)

v =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
SCF.while loc: loc do
region do
block _() do
Expand Down Expand Up @@ -1119,7 +1118,7 @@ defmodule Charms.Defm.Expander do
{ptr, state, env} = expand(ptr, state, env)

v =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
zero = Index.constant(value: Attribute.index(0)) >>> Type.index()
lower_bound = zero
upper_bound = Index.casts(len) >>> Type.index()
Expand Down Expand Up @@ -1170,7 +1169,7 @@ defmodule Charms.Defm.Expander do
op: op,
arguments: args,
ctx: state.mlir.ctx,
block: state.mlir.blk,
blk: state.mlir.blk,
loc: MLIR.Location.from_env(env)
}
|> Beaver.SSA.put_results(return_types)
Expand Down Expand Up @@ -1209,7 +1208,7 @@ defmodule Charms.Defm.Expander do
{type, state, env} = expand(type, state, env)

value =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
loc = MLIR.Location.from_env(env)

cond do
Expand Down
8 changes: 4 additions & 4 deletions lib/charms/defm/pass/create_absent_func.ex
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ defmodule Charms.Defm.Pass.CreateAbsentFunc do

@default_visibility "private"
# create absent if it is a function not found in the symbol table
defp create_func(ctx, block, symbol_table, ir, created) do
defp create_func(ctx, blk, symbol_table, ir, created) do
with op = %MLIR.Operation{} <- ir,
"func.call" <- MLIR.Operation.name(op),
{name, arg_types, ret_types} <- decompose(op),
true <- MLIR.is_null(mlirSymbolTableLookup(symbol_table, name)),
name_str <- MLIR.StringRef.to_string(name),
true <- MLIR.null?(mlirSymbolTableLookup(symbol_table, name)),
name_str <- MLIR.to_string(name),
false <- MapSet.member?(created, name_str) do
mlir ctx: ctx, block: block do
mlir ctx: ctx, blk: blk do
{arg_types, ret_types} =
if s = Beaver.ENIF.signature(ctx, String.to_atom(name_str)) do
s
Expand Down
Loading

0 comments on commit 123ab77

Please sign in to comment.