Skip to content

Commit

Permalink
[Runtime] Support arbitrary bitwidth integer in LLVM execution engine (
Browse files Browse the repository at this point in the history
…#37)

* Add T test

* Add make_any_bitwidth_np_array

* Reorg new_arg

* Fix reverting struct array to int array

* Fix output type

* Add i8 i16 return

* Add support for i<16

* Add typedef to global_var

* Fix return_dtype

* Add np_supported_types

* Fix np_type

* Cast np_type

* Fix pylint

* Add UInt in builder

* Update tests

* Fix return arg

* Fix return same
  • Loading branch information
chhzh123 authored Aug 21, 2023
1 parent b1b33a7 commit df35d4c
Show file tree
Hide file tree
Showing 6 changed files with 365 additions and 36 deletions.
6 changes: 4 additions & 2 deletions allo/customize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .ir.builder import ASTTransformer
from .ir.infer import TypeInferer
from .ir.transform import get_affine_loop_nests, find_loop_in_bands
from .ir.types import AlloType
from .build_module import _mlir_lower_pipeline, lower_linalg_and_attach_names
from .module import LLVMModule, HLSModule

Expand All @@ -54,9 +55,10 @@ def _get_global_vars(_func):
global_vars = _func.__globals__.copy()

# Get back to the outer-most scope (user-defined function)
# Mainly used to get the annotation definitions, which are probably not defined in __globals__
# Mainly used to get the annotation definitions (shape and type),
# which are probably not defined in __globals__
for name, var in inspect.stack()[2][0].f_locals.items():
if isinstance(var, (int, float)) or inspect.isfunction(var):
if isinstance(var, (int, float, AlloType)) or inspect.isfunction(var):
global_vars[name] = var

freevar_names = _func.__code__.co_freevars
Expand Down
3 changes: 3 additions & 0 deletions allo/ir/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,16 +303,19 @@ def build_general_binop(ctx, node, lhs, rhs):
ast.Add: {
Float: arith_d.AddFOp,
Int: arith_d.AddIOp,
UInt: arith_d.AddIOp,
Fixed: hcl_d.AddFixedOp,
},
ast.Sub: {
Float: arith_d.SubFOp,
Int: arith_d.SubIOp,
UInt: arith_d.SubIOp,
Fixed: hcl_d.SubFixedOp,
},
ast.Mult: {
Float: arith_d.MulFOp,
Int: arith_d.MulIOp,
UInt: arith_d.MulIOp,
Fixed: hcl_d.MulFixedOp,
},
ast.Div: {
Expand Down
3 changes: 0 additions & 3 deletions allo/ir/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,6 @@ def visit_AnnAssign(ctx, node):
rhs = TypeInferer.visit_constant_tensor(ctx, node)
else:
raise RuntimeError("Unsupported data type")
# assert (
# rhs.dtype == target_dtype
# ), f"Type mismatch, got {rhs.dtype} and {target_dtype} for {node.__class__.__name__} `{node.target.id}`"
if not isinstance(node.value, ast.Constant):
assert (
rhs.shape == target_shape
Expand Down
Loading

0 comments on commit df35d4c

Please sign in to comment.