8
8
from luisa_lang .utils import IdentityDict , check_type , is_generic_class
9
9
import luisa_lang .hir as hir
10
10
from luisa_lang .hir import PyTreeStructure
11
- from typing import Any , Callable , Dict , List , Mapping , Optional , Sequence , Tuple , Type , Union , cast
11
+ from typing import (
12
+ Any ,
13
+ Callable ,
14
+ Dict ,
15
+ List ,
16
+ Mapping ,
17
+ Optional ,
18
+ Sequence ,
19
+ Tuple ,
20
+ Type ,
21
+ Union ,
22
+ cast ,
23
+ )
12
24
13
25
14
26
class Scope :
@@ -43,12 +55,12 @@ class FuncTracer:
43
55
locals : List [hir .Var ]
44
56
params : List [hir .Var ]
45
57
scopes : List [Scope ]
46
- ret_type : hir . Type | None
58
+ ret_type : Type [ "JitVar" ] | None
47
59
func_globals : Dict [str , Any ]
48
60
name : str
49
61
entry_bb : hir .BasicBlock
50
62
51
- def __init__ (self , name :str , func_globals : Dict [str , Any ]):
63
+ def __init__ (self , name : str , func_globals : Dict [str , Any ]):
52
64
self .locals = []
53
65
self .py_locals = {}
54
66
self .params = []
@@ -76,7 +88,6 @@ def create_var(self, name: str, ty: hir.Type, is_param: bool) -> hir.Var:
76
88
self .params .append (var )
77
89
return var
78
90
79
-
80
91
def add_py_var (self , name : str , obj : object ):
81
92
assert not isinstance (obj , JitVar )
82
93
if name in self .py_locals :
@@ -119,7 +130,7 @@ def set_var(self, key: str, value: Any) -> None:
119
130
else :
120
131
self .py_locals [key ] = value
121
132
122
- def check_return_type (self , ty :hir . Type ) :
133
+ def check_return_type (self , ty : Type [ "JitVar" ]) -> None :
123
134
if self .ret_type is None :
124
135
self .ret_type = ty
125
136
else :
@@ -130,7 +141,7 @@ def check_return_type(self, ty:hir.Type):
130
141
131
142
def cur_bb (self ) -> hir .BasicBlock :
132
143
return self .scopes [- 1 ].bb
133
-
144
+
134
145
def set_cur_bb (self , bb : hir .BasicBlock ) -> None :
135
146
"""
136
147
Set the current basic block to `bb`
@@ -142,7 +153,14 @@ def finalize(self) -> hir.Function:
142
153
assert len (self .scopes ) == 1
143
154
entry_bb = self .entry_bb
144
155
assert self .ret_type is not None
145
- return hir .Function (self .name , self .params , self .locals , entry_bb , self .ret_type )
156
+ return hir .Function (
157
+ self .name ,
158
+ self .params ,
159
+ self .locals ,
160
+ entry_bb ,
161
+ self .ret_type ,
162
+ self .ret_type .hir_type (),
163
+ )
146
164
147
165
148
166
FUNC_STACK : List [FuncTracer ] = []
@@ -160,7 +178,6 @@ def push_to_current_bb[T: hir.Node](node: T) -> T:
160
178
return current_func ().cur_bb ().append (node )
161
179
162
180
163
-
164
181
class Symbolic :
165
182
node : hir .Value
166
183
scope : Scope
@@ -178,7 +195,9 @@ class FlattenedTree:
178
195
children : List ["FlattenedTree" ]
179
196
180
197
def __init__ (
181
- self , metadata : Tuple [Type [Any ], Tuple [Any ], Any ], children : List ["FlattenedTree" ]
198
+ self ,
199
+ metadata : Tuple [Type [Any ], Tuple [Any ], Any ],
200
+ children : List ["FlattenedTree" ],
182
201
):
183
202
self .metadata = metadata
184
203
self .children = children
@@ -214,8 +233,8 @@ def structure(self) -> hir.PyTreeStructure:
214
233
return hir .PyTreeStructure (
215
234
(typ , self .metadata [1 ], self .metadata [2 ]), children
216
235
)
217
-
218
- def collect_jitvars (self ) -> List [' JitVar' ]:
236
+
237
+ def collect_jitvars (self ) -> List [" JitVar" ]:
219
238
"""
220
239
Collect all JitVar instances from the flattened tree
221
240
"""
@@ -275,7 +294,7 @@ class JitVar:
275
294
__symbolic__ : Optional [Symbolic ]
276
295
dtype : type [Any ]
277
296
278
- def __init__ (self , dtype :type [Any ]):
297
+ def __init__ (self , dtype : type [Any ]):
279
298
"""
280
299
Zero-initialize a variable with given data type
281
300
"""
@@ -316,6 +335,14 @@ def from_hir_node[T: JitVar](cls: type[T], node: hir.Value) -> T:
316
335
instance .dtype = cls
317
336
return instance
318
337
338
+ @classmethod
339
+ def hir_type (cls ) -> hir .Type :
340
+ """
341
+ Get the HIR type of the JitVar
342
+ """
343
+ # TODO: handle generic types
344
+ return hir .get_dsl_type (cls ).default ()
345
+
319
346
def symbolic (self ) -> Symbolic :
320
347
"""
321
348
Retrieve the internal symbolic representation of the variable. This is used for internal DSL code generation.
@@ -423,7 +450,8 @@ def unflatten_primitive(tree: FlattenedTree) -> Any:
423
450
424
451
def flatten_list (obj : List [Any ]) -> FlattenedTree :
425
452
return FlattenedTree (
426
- (list , cast (Tuple [Any , ...], tuple ()), None ), [tree_flatten (o , True ) for o in obj ]
453
+ (list , cast (Tuple [Any , ...], tuple ()), None ),
454
+ [tree_flatten (o , True ) for o in obj ],
427
455
)
428
456
429
457
def unflatten_list (tree : FlattenedTree ) -> List [Any ]:
@@ -435,7 +463,8 @@ def unflatten_list(tree: FlattenedTree) -> List[Any]:
435
463
436
464
def flatten_tuple (obj : Tuple [Any , ...]) -> FlattenedTree :
437
465
return FlattenedTree (
438
- (tuple , cast (Tuple [Any , ...], tuple ()), None ), [tree_flatten (o , True ) for o in obj ]
466
+ (tuple , cast (Tuple [Any , ...], tuple ()), None ),
467
+ [tree_flatten (o , True ) for o in obj ],
439
468
)
440
469
441
470
def unflatten_tuple (tree : FlattenedTree ) -> Tuple [Any , ...]:
@@ -491,7 +520,9 @@ def create_intrinsic_node[T: JitVar](
491
520
elif isinstance (a , hir .Value ):
492
521
nodes .append (a )
493
522
else :
494
- raise ValueError (f"Argument [{ i } ] `{ a } ` of type { type (a )} is not a valid DSL variable or HIR node" )
523
+ raise ValueError (
524
+ f"Argument [{ i } ] `{ a } ` of type { type (a )} is not a valid DSL variable or HIR node"
525
+ )
495
526
if ret_type is not None :
496
527
ret_dsl_type = hir .get_dsl_type (ret_type ).default ()
497
528
if ret_dsl_type is None :
@@ -500,25 +531,28 @@ def create_intrinsic_node[T: JitVar](
500
531
ret_dsl_type = hir .UnitType ()
501
532
return push_to_current_bb (hir .Intrinsic (name , nodes , ret_dsl_type ))
502
533
534
+
503
535
def __escape__ (x : Any ) -> Any :
504
536
return x
505
537
538
+
506
539
def __intrinsic_checked__ [T ](
507
540
name : str , arg_types : Sequence [Any ], ret_type : type [T ], * args
508
541
) -> T :
509
542
"""
510
543
Call an intrinsic function with type checking.
511
544
"""
512
- assert len (args ) == len (arg_types ), (
513
- f"Intrinsic { name } expects { len ( arg_types ) } arguments, got { len ( args ) } "
514
- )
545
+ assert len (args ) == len (
546
+ arg_types
547
+ ), f"Intrinsic { name } expects { len ( arg_types ) } arguments, got { len ( args ) } "
515
548
for i , (arg , arg_type ) in enumerate (zip (args , arg_types )):
516
549
if not check_type (arg_type , arg ):
517
550
raise ValueError (
518
551
f"Argument { i } of intrinsic { name } is not of type { arg_type } , got { type (arg )} "
519
552
)
520
553
return __intrinsic__ (name , ret_type , * args )
521
554
555
+
522
556
def __intrinsic__ [T ](name : str , ret_type : type [T ], * args ) -> T :
523
557
"""
524
558
Call an intrinsic function. This function does not check the arguemnts.
@@ -572,19 +606,21 @@ def on_exit(self) -> None:
572
606
"""
573
607
pass
574
608
609
+
575
610
class ScopeGuard :
576
611
def __enter__ (self ) -> Scope :
577
612
"""
578
613
Enter a new scope
579
614
"""
580
615
return current_func ().push_scope ()
581
-
616
+
582
617
def __exit__ (self , exc_type , exc_val , exc_tb ):
583
618
"""
584
619
Exit the current scope
585
620
"""
586
621
current_func ().pop_scope ()
587
622
623
+
588
624
class IfFrame (ControlFlowFrame ):
589
625
static_cond : Optional [bool ]
590
626
true_bb : Optional [hir .BasicBlock ]
@@ -625,12 +661,16 @@ def on_exit(self) -> None:
625
661
cond = self .cond
626
662
assert isinstance (cond , JitVar ), "Condition must be a DSL variable"
627
663
merge_bb = hir .BasicBlock ()
628
- if_stmt = hir .If (cond .symbolic ().node ,
629
- cast (hir .BasicBlock , self .true_bb ),
630
- cast (hir .BasicBlock , self .false_bb ), merge_bb )
664
+ if_stmt = hir .If (
665
+ cond .symbolic ().node ,
666
+ cast (hir .BasicBlock , self .true_bb ),
667
+ cast (hir .BasicBlock , self .false_bb ),
668
+ merge_bb ,
669
+ )
631
670
push_to_current_bb (if_stmt )
632
671
current_func ().set_cur_bb (merge_bb )
633
672
673
+
634
674
class ControlFrameGuard [T : ControlFlowFrame ]:
635
675
cf_type : type [T ]
636
676
args : Tuple [Any , ...]
@@ -707,31 +747,52 @@ def __exit__(self, exc_type, exc_val, exc_tb):
707
747
"GtE" : ["__ge__" , "__le__" ],
708
748
}
709
749
750
+
751
+ class LineTable :
752
+ span_of_line : Dict [int , hir .Span ]
753
+
754
+
710
755
class TraceContext :
711
756
cf_frame : ControlFlowFrame
712
757
is_top_level : bool
713
758
top_level_func : Optional [hir .Function ]
759
+ line_table : Optional [LineTable ] # for better error reporting
760
+ current_line : Optional [int ] = None
714
761
715
762
def __init__ (self , is_top_level ):
716
763
self .cf_frame = ControlFlowFrame (parent = None )
717
764
self .is_top_level = is_top_level
718
765
self .top_level_func = None
719
766
767
+ def set_line_table (self , line_table : LineTable ) -> None :
768
+ self .line_table = line_table
769
+
770
+ def set_current_line (self , line : int ) -> None :
771
+ self .current_line = line
772
+
773
+ def current_span (self ) -> hir .Span | None :
774
+ """
775
+ Get the current span for error reporting
776
+ """
777
+ if self .line_table is not None and self .current_line is not None :
778
+ return self .line_table .span_of_line .get (self .current_line , None )
779
+ return None
780
+
720
781
def is_parent_static (self ) -> bool :
721
782
return self .cf_frame .is_static
722
783
723
784
def if_ (self , cond : Any ) -> ControlFrameGuard [IfFrame ]:
724
785
return ControlFrameGuard (self , IfFrame , cond )
725
-
786
+
726
787
def scope (self ) -> ScopeGuard :
727
788
return ScopeGuard ()
728
789
729
790
def return_ (self , expr : JitVar ) -> None :
730
791
"""
731
792
Return a value from the current function
732
793
"""
733
- ty = expr . _symbolic_type ()
734
- current_func ().check_return_type (ty )
794
+ assert isinstance ( expr , JitVar ), "Return expression must be a DSL variable"
795
+ current_func ().check_return_type (type ( expr )) # TODO: handle generics
735
796
push_to_current_bb (hir .Return (expr .symbolic ().node ))
736
797
737
798
def redirect_binary (self , op , x , y ):
@@ -745,7 +806,7 @@ def redirect_binary(self, op, x, y):
745
806
raise ValueError (
746
807
f"Binary operation { op } not supported for { type (x )} and { type (y )} "
747
808
)
748
-
809
+
749
810
def redirect_cmp (self , op , x , y ):
750
811
op , rop = CMP_OP_TO_METHOD_NAMES [op ]
751
812
if hasattr (x , op ):
@@ -758,11 +819,11 @@ def redirect_cmp(self, op, x, y):
758
819
)
759
820
760
821
def redirect_call (self , f , * args , ** kwargs ):
761
- return f (* args , ** kwargs , __lc_ctx__ = self ) # TODO: shoould not always pass self
822
+ return f (* args , ** kwargs , __lc_ctx__ = self ) # TODO: shoould not always pass self
762
823
763
824
def intrinsic (self , f , * args , ** kwargs ):
764
825
return __intrinsic__ (f , * args , ** kwargs )
765
-
826
+
766
827
def intrinsic_checked (self , f , arg_types , ret_type , * args ):
767
828
return __intrinsic_checked__ (f , arg_types , ret_type , * args )
768
829
@@ -837,7 +898,7 @@ def _invoke_function_tracer(
837
898
trace_ctx = TraceContext (False )
838
899
839
900
# args is Type | object
840
- func_tracer = FuncTracer (f .__name__ .replace ('.' , '_' ), globalns )
901
+ func_tracer = FuncTracer (f .__name__ .replace ("." , "_" ), globalns )
841
902
FUNC_STACK .append (func_tracer )
842
903
try :
843
904
args_vars , kwargs_vars , jit_vars = _encode_func_args (args )
@@ -852,7 +913,7 @@ class KernelTracer:
852
913
top_level_tracer : FuncTracer
853
914
854
915
def __init__ (self , func_globals : Dict [str , Any ]):
855
- self .top_level_tracer = FuncTracer (' __kernel__' , func_globals )
916
+ self .top_level_tracer = FuncTracer (" __kernel__" , func_globals )
856
917
857
918
def __enter__ (self ) -> FuncTracer :
858
919
FUNC_STACK .append (self .top_level_tracer )
0 commit comments