Skip to content

Commit 72cb5c6

Browse files
committed
simple function call
1 parent 146d46d commit 72cb5c6

File tree

3 files changed

+101
-34
lines changed

3 files changed

+101
-34
lines changed

luisa_lang/_builtin_decor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,23 +271,26 @@ def wrapper(*args, __lc_ctx__: Optional[TraceContext] = None, **kwargs):
271271
)
272272
# print(instantiated_func.return_type)
273273
# func_tracer = current_func()
274-
rt = instantiated_func.return_type
274+
rt = instantiated_func.return_jitvar_type
275+
assert issubclass(rt, JitVar), f"Return type {rt} is not a JitVar"
275276
jitvar_args: List[JitVar] = []
276277
for arg in pytree_args:
277278
jitvar_args.extend(arg.collect_jitvars())
278279
for k, v in pytree_kwargs.items():
279280
jitvar_args.extend(v.collect_jitvars())
280281
# TODO: handle kwargs properly
281282
if not __lc_ctx__.is_top_level:
282-
push_to_current_bb(
283+
ret_node = push_to_current_bb(
283284
hir.Call(
284285
op=instantiated_func,
285286
args=[x.symbolic().node for x in jitvar_args],
286-
type=rt,
287+
type=instantiated_func.return_type,
287288
)
288289
)
290+
return rt.from_hir_node(ret_node)
289291
else:
290292
__lc_ctx__.top_level_func = instantiated_func
293+
291294

292295
# Copy over important attributes from the original function
293296
wrapper.__name__ = f.__name__

luisa_lang/hir.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,7 @@ class Function:
604604
params: List["Var"]
605605
locals: List["Var"]
606606
body: "BasicBlock"
607+
return_jitvar_type: typing.Type['Any']
607608
return_type: Type
608609

609610
def __init__(
@@ -612,13 +613,15 @@ def __init__(
612613
params: List["Var"],
613614
locals: List["Var"],
614615
body: "BasicBlock",
615-
return_type: Type,
616+
return_jitvar_type: typing.Type['Any'],
617+
return_type: Type
616618
) -> None:
617619
self.name = name
618620
self.params = params
619621
self.locals = locals
620622
self.body = body
621623
self.return_type = return_type
624+
self.return_jitvar_type = return_jitvar_type
622625

623626

624627
class Node:

luisa_lang/lang_runtime.py

Lines changed: 91 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,19 @@
88
from luisa_lang.utils import IdentityDict, check_type, is_generic_class
99
import luisa_lang.hir as hir
1010
from luisa_lang.hir import PyTreeStructure
11-
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast
11+
from typing import (
12+
Any,
13+
Callable,
14+
Dict,
15+
List,
16+
Mapping,
17+
Optional,
18+
Sequence,
19+
Tuple,
20+
Type,
21+
Union,
22+
cast,
23+
)
1224

1325

1426
class Scope:
@@ -43,12 +55,12 @@ class FuncTracer:
4355
locals: List[hir.Var]
4456
params: List[hir.Var]
4557
scopes: List[Scope]
46-
ret_type: hir.Type | None
58+
ret_type: Type["JitVar"] | None
4759
func_globals: Dict[str, Any]
4860
name: str
4961
entry_bb: hir.BasicBlock
5062

51-
def __init__(self, name:str, func_globals: Dict[str, Any]):
63+
def __init__(self, name: str, func_globals: Dict[str, Any]):
5264
self.locals = []
5365
self.py_locals = {}
5466
self.params = []
@@ -76,7 +88,6 @@ def create_var(self, name: str, ty: hir.Type, is_param: bool) -> hir.Var:
7688
self.params.append(var)
7789
return var
7890

79-
8091
def add_py_var(self, name: str, obj: object):
8192
assert not isinstance(obj, JitVar)
8293
if name in self.py_locals:
@@ -119,7 +130,7 @@ def set_var(self, key: str, value: Any) -> None:
119130
else:
120131
self.py_locals[key] = value
121132

122-
def check_return_type(self, ty:hir.Type):
133+
def check_return_type(self, ty: Type["JitVar"]) -> None:
123134
if self.ret_type is None:
124135
self.ret_type = ty
125136
else:
@@ -130,7 +141,7 @@ def check_return_type(self, ty:hir.Type):
130141

131142
def cur_bb(self) -> hir.BasicBlock:
132143
return self.scopes[-1].bb
133-
144+
134145
def set_cur_bb(self, bb: hir.BasicBlock) -> None:
135146
"""
136147
Set the current basic block to `bb`
@@ -142,7 +153,14 @@ def finalize(self) -> hir.Function:
142153
assert len(self.scopes) == 1
143154
entry_bb = self.entry_bb
144155
assert self.ret_type is not None
145-
return hir.Function(self.name, self.params, self.locals, entry_bb, self.ret_type)
156+
return hir.Function(
157+
self.name,
158+
self.params,
159+
self.locals,
160+
entry_bb,
161+
self.ret_type,
162+
self.ret_type.hir_type(),
163+
)
146164

147165

148166
FUNC_STACK: List[FuncTracer] = []
@@ -160,7 +178,6 @@ def push_to_current_bb[T: hir.Node](node: T) -> T:
160178
return current_func().cur_bb().append(node)
161179

162180

163-
164181
class Symbolic:
165182
node: hir.Value
166183
scope: Scope
@@ -178,7 +195,9 @@ class FlattenedTree:
178195
children: List["FlattenedTree"]
179196

180197
def __init__(
181-
self, metadata: Tuple[Type[Any], Tuple[Any], Any], children: List["FlattenedTree"]
198+
self,
199+
metadata: Tuple[Type[Any], Tuple[Any], Any],
200+
children: List["FlattenedTree"],
182201
):
183202
self.metadata = metadata
184203
self.children = children
@@ -214,8 +233,8 @@ def structure(self) -> hir.PyTreeStructure:
214233
return hir.PyTreeStructure(
215234
(typ, self.metadata[1], self.metadata[2]), children
216235
)
217-
218-
def collect_jitvars(self) -> List['JitVar']:
236+
237+
def collect_jitvars(self) -> List["JitVar"]:
219238
"""
220239
Collect all JitVar instances from the flattened tree
221240
"""
@@ -275,7 +294,7 @@ class JitVar:
275294
__symbolic__: Optional[Symbolic]
276295
dtype: type[Any]
277296

278-
def __init__(self, dtype:type[Any]):
297+
def __init__(self, dtype: type[Any]):
279298
"""
280299
Zero-initialize a variable with given data type
281300
"""
@@ -316,6 +335,14 @@ def from_hir_node[T: JitVar](cls: type[T], node: hir.Value) -> T:
316335
instance.dtype = cls
317336
return instance
318337

338+
@classmethod
339+
def hir_type(cls) -> hir.Type:
340+
"""
341+
Get the HIR type of the JitVar
342+
"""
343+
# TODO: handle generic types
344+
return hir.get_dsl_type(cls).default()
345+
319346
def symbolic(self) -> Symbolic:
320347
"""
321348
Retrieve the internal symbolic representation of the variable. This is used for internal DSL code generation.
@@ -423,7 +450,8 @@ def unflatten_primitive(tree: FlattenedTree) -> Any:
423450

424451
def flatten_list(obj: List[Any]) -> FlattenedTree:
425452
return FlattenedTree(
426-
(list, cast(Tuple[Any, ...], tuple()), None), [tree_flatten(o, True) for o in obj]
453+
(list, cast(Tuple[Any, ...], tuple()), None),
454+
[tree_flatten(o, True) for o in obj],
427455
)
428456

429457
def unflatten_list(tree: FlattenedTree) -> List[Any]:
@@ -435,7 +463,8 @@ def unflatten_list(tree: FlattenedTree) -> List[Any]:
435463

436464
def flatten_tuple(obj: Tuple[Any, ...]) -> FlattenedTree:
437465
return FlattenedTree(
438-
(tuple, cast(Tuple[Any, ...], tuple()), None), [tree_flatten(o, True) for o in obj]
466+
(tuple, cast(Tuple[Any, ...], tuple()), None),
467+
[tree_flatten(o, True) for o in obj],
439468
)
440469

441470
def unflatten_tuple(tree: FlattenedTree) -> Tuple[Any, ...]:
@@ -491,7 +520,9 @@ def create_intrinsic_node[T: JitVar](
491520
elif isinstance(a, hir.Value):
492521
nodes.append(a)
493522
else:
494-
raise ValueError(f"Argument [{i}] `{a}` of type {type(a)} is not a valid DSL variable or HIR node")
523+
raise ValueError(
524+
f"Argument [{i}] `{a}` of type {type(a)} is not a valid DSL variable or HIR node"
525+
)
495526
if ret_type is not None:
496527
ret_dsl_type = hir.get_dsl_type(ret_type).default()
497528
if ret_dsl_type is None:
@@ -500,25 +531,28 @@ def create_intrinsic_node[T: JitVar](
500531
ret_dsl_type = hir.UnitType()
501532
return push_to_current_bb(hir.Intrinsic(name, nodes, ret_dsl_type))
502533

534+
503535
def __escape__(x: Any) -> Any:
504536
return x
505537

538+
506539
def __intrinsic_checked__[T](
507540
name: str, arg_types: Sequence[Any], ret_type: type[T], *args
508541
) -> T:
509542
"""
510543
Call an intrinsic function with type checking.
511544
"""
512-
assert len(args) == len(arg_types), (
513-
f"Intrinsic {name} expects {len(arg_types)} arguments, got {len(args)}"
514-
)
545+
assert len(args) == len(
546+
arg_types
547+
), f"Intrinsic {name} expects {len(arg_types)} arguments, got {len(args)}"
515548
for i, (arg, arg_type) in enumerate(zip(args, arg_types)):
516549
if not check_type(arg_type, arg):
517550
raise ValueError(
518551
f"Argument {i} of intrinsic {name} is not of type {arg_type}, got {type(arg)}"
519552
)
520553
return __intrinsic__(name, ret_type, *args)
521554

555+
522556
def __intrinsic__[T](name: str, ret_type: type[T], *args) -> T:
523557
"""
524558
Call an intrinsic function. This function does not check the arguemnts.
@@ -572,19 +606,21 @@ def on_exit(self) -> None:
572606
"""
573607
pass
574608

609+
575610
class ScopeGuard:
576611
def __enter__(self) -> Scope:
577612
"""
578613
Enter a new scope
579614
"""
580615
return current_func().push_scope()
581-
616+
582617
def __exit__(self, exc_type, exc_val, exc_tb):
583618
"""
584619
Exit the current scope
585620
"""
586621
current_func().pop_scope()
587622

623+
588624
class IfFrame(ControlFlowFrame):
589625
static_cond: Optional[bool]
590626
true_bb: Optional[hir.BasicBlock]
@@ -625,12 +661,16 @@ def on_exit(self) -> None:
625661
cond = self.cond
626662
assert isinstance(cond, JitVar), "Condition must be a DSL variable"
627663
merge_bb = hir.BasicBlock()
628-
if_stmt = hir.If(cond.symbolic().node,
629-
cast(hir.BasicBlock, self.true_bb),
630-
cast(hir.BasicBlock, self.false_bb), merge_bb)
664+
if_stmt = hir.If(
665+
cond.symbolic().node,
666+
cast(hir.BasicBlock, self.true_bb),
667+
cast(hir.BasicBlock, self.false_bb),
668+
merge_bb,
669+
)
631670
push_to_current_bb(if_stmt)
632671
current_func().set_cur_bb(merge_bb)
633672

673+
634674
class ControlFrameGuard[T: ControlFlowFrame]:
635675
cf_type: type[T]
636676
args: Tuple[Any, ...]
@@ -707,31 +747,52 @@ def __exit__(self, exc_type, exc_val, exc_tb):
707747
"GtE": ["__ge__", "__le__"],
708748
}
709749

750+
751+
class LineTable:
752+
span_of_line: Dict[int, hir.Span]
753+
754+
710755
class TraceContext:
711756
cf_frame: ControlFlowFrame
712757
is_top_level: bool
713758
top_level_func: Optional[hir.Function]
759+
line_table: Optional[LineTable] # for better error reporting
760+
current_line: Optional[int] = None
714761

715762
def __init__(self, is_top_level):
716763
self.cf_frame = ControlFlowFrame(parent=None)
717764
self.is_top_level = is_top_level
718765
self.top_level_func = None
719766

767+
def set_line_table(self, line_table: LineTable) -> None:
768+
self.line_table = line_table
769+
770+
def set_current_line(self, line: int) -> None:
771+
self.current_line = line
772+
773+
def current_span(self) -> hir.Span | None:
774+
"""
775+
Get the current span for error reporting
776+
"""
777+
if self.line_table is not None and self.current_line is not None:
778+
return self.line_table.span_of_line.get(self.current_line, None)
779+
return None
780+
720781
def is_parent_static(self) -> bool:
721782
return self.cf_frame.is_static
722783

723784
def if_(self, cond: Any) -> ControlFrameGuard[IfFrame]:
724785
return ControlFrameGuard(self, IfFrame, cond)
725-
786+
726787
def scope(self) -> ScopeGuard:
727788
return ScopeGuard()
728789

729790
def return_(self, expr: JitVar) -> None:
730791
"""
731792
Return a value from the current function
732793
"""
733-
ty = expr._symbolic_type()
734-
current_func().check_return_type(ty)
794+
assert isinstance(expr, JitVar), "Return expression must be a DSL variable"
795+
current_func().check_return_type(type(expr)) # TODO: handle generics
735796
push_to_current_bb(hir.Return(expr.symbolic().node))
736797

737798
def redirect_binary(self, op, x, y):
@@ -745,7 +806,7 @@ def redirect_binary(self, op, x, y):
745806
raise ValueError(
746807
f"Binary operation {op} not supported for {type(x)} and {type(y)}"
747808
)
748-
809+
749810
def redirect_cmp(self, op, x, y):
750811
op, rop = CMP_OP_TO_METHOD_NAMES[op]
751812
if hasattr(x, op):
@@ -758,11 +819,11 @@ def redirect_cmp(self, op, x, y):
758819
)
759820

760821
def redirect_call(self, f, *args, **kwargs):
761-
return f(*args, **kwargs, __lc_ctx__=self) # TODO: shoould not always pass self
822+
return f(*args, **kwargs, __lc_ctx__=self) # TODO: shoould not always pass self
762823

763824
def intrinsic(self, f, *args, **kwargs):
764825
return __intrinsic__(f, *args, **kwargs)
765-
826+
766827
def intrinsic_checked(self, f, arg_types, ret_type, *args):
767828
return __intrinsic_checked__(f, arg_types, ret_type, *args)
768829

@@ -837,7 +898,7 @@ def _invoke_function_tracer(
837898
trace_ctx = TraceContext(False)
838899

839900
# args is Type | object
840-
func_tracer = FuncTracer(f.__name__.replace('.','_'), globalns)
901+
func_tracer = FuncTracer(f.__name__.replace(".", "_"), globalns)
841902
FUNC_STACK.append(func_tracer)
842903
try:
843904
args_vars, kwargs_vars, jit_vars = _encode_func_args(args)
@@ -852,7 +913,7 @@ class KernelTracer:
852913
top_level_tracer: FuncTracer
853914

854915
def __init__(self, func_globals: Dict[str, Any]):
855-
self.top_level_tracer = FuncTracer('__kernel__', func_globals)
916+
self.top_level_tracer = FuncTracer("__kernel__", func_globals)
856917

857918
def __enter__(self) -> FuncTracer:
858919
FUNC_STACK.append(self.top_level_tracer)

0 commit comments

Comments
 (0)