2
2
from ast import NodeTransformer
3
3
import copy
4
4
from typing import Callable , Any , List , Set , cast
5
- from luisa_lang .utils import checked_cast , retrieve_ast_and_filename , NestedHashMap
5
+ from luisa_lang .utils import Span , checked_cast , retrieve_ast_and_filename , NestedHashMap
6
6
7
7
"""
8
8
Rewrite rules:
@@ -163,17 +163,21 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
163
163
return node
164
164
165
165
def visit_Name (self , node : ast .Name ) -> Any :
166
+ span = Span .from_ast (node )
167
+ assert span is not None
166
168
# rewrite to __lc_ctx__.name
167
- return ast .Subscript (
169
+ return span . apply_to_ast ( ast .Subscript (
168
170
value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
169
171
slice = ast .Constant (value = node .id ),
170
172
ctx = node .ctx ,
171
- )
173
+ ))
172
174
173
175
def visit_Assign (self , node : ast .Assign ) -> Any :
174
176
return self .generic_visit (node )
175
177
176
178
def visit_AnnAssign (self , node : ast .AnnAssign ) -> Any :
179
+ span = Span .from_ast (node )
180
+ assert span is not None
177
181
target = checked_cast (ast .expr , self .visit (node .target ))
178
182
assert isinstance (target , (ast .Name , ast .Subscript , ast .Attribute ))
179
183
target .ctx = ast .Load ()
@@ -193,81 +197,93 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
193
197
target = copy .deepcopy (target )
194
198
target .ctx = ast .Store ()
195
199
assign = ast .Assign (targets = [target ], value = self .visit (node .value ))
200
+ span .apply_to_ast (anno )
201
+ span .apply_to_ast (assign )
196
202
return [anno , assign ]
197
203
198
204
def visit_Call (self , node : ast .Call ) -> Any :
205
+ span = Span .from_ast (node )
206
+ assert span is not None
199
207
# first check if it is of form `__intrinsic__(...)`
200
208
if isinstance (node .func , ast .Name ):
201
209
if node .func .id in NO_REWRITE_FUNCTIONS :
202
210
return node
203
211
if node .func .id == "__intrinsic__" or node .func .id == "__intrinsic_checked__" :
204
212
# rewrite to __lc_ctx__.intrinsic(...)
205
- return ast .Call (
213
+ return span . apply_to_ast ( ast .Call (
206
214
func = ast .Attribute (
207
215
value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
208
216
attr = node .func .id [2 :- 2 ],
209
217
ctx = ast .Load (),
210
218
),
211
219
args = [self .visit (arg ) for arg in node .args ],
212
220
keywords = [self .visit (kw ) for kw in node .keywords ],
213
- )
221
+ ))
214
222
# rewrite to __lc_ctx__.redirect_call(func, args...)
215
223
func = self .visit (node .func )
216
224
args = [self .visit (arg ) for arg in node .args ]
217
225
keywords = [self .visit (kw ) for kw in node .keywords ]
218
- return ast .Call (
226
+ return span . apply_to_ast ( ast .Call (
219
227
func = ast .Attribute (
220
228
value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
221
229
attr = "redirect_call" ,
222
230
ctx = ast .Load (),
223
231
),
224
232
args = [func ] + args ,
225
233
keywords = keywords ,
226
- )
234
+ ))
227
235
228
236
def visit_BinOp (self , node : ast .BinOp ) -> Any :
237
+ span = Span .from_ast (node )
238
+ assert span is not None
229
239
lhs = self .visit (node .left )
230
240
rhs = self .visit (node .right )
231
- return ast .Call (
241
+ return span . apply_to_ast ( ast .Call (
232
242
func = ast .Attribute (
233
243
value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
234
244
attr = "redirect_binary" ,
235
245
ctx = ast .Load (),
236
246
),
237
247
args = [ast .Constant (value = type (node .op ).__name__ ), lhs , rhs ],
238
248
keywords = [],
239
- )
249
+ ))
240
250
241
251
def visit_UnaryOp (self , node : ast .UnaryOp ) -> Any :
252
+ span = Span .from_ast (node )
253
+ assert span is not None
242
254
operand = self .visit (node .operand )
243
- return ast .Call (
255
+ return span . apply_to_ast ( ast .Call (
244
256
func = ast .Attribute (
245
257
value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
246
258
attr = "redirect_unary" ,
247
259
ctx = ast .Load (),
248
260
),
249
261
args = [ast .Constant (value = type (node .op ).__name__ ), operand ],
250
262
keywords = [],
251
- )
263
+ ))
252
264
253
265
def visit_Compare (self , node : ast .Compare ) -> Any :
266
+ span = Span .from_ast (node )
267
+ assert span is not None
254
268
if len (node .ops ) != 1 or len (node .comparators ) != 1 :
255
269
raise NotImplementedError ("Only single comparison is supported" )
256
270
left = self .visit (node .left )
257
271
right = self .visit (node .comparators [0 ])
258
- return ast .Call (
272
+ return span . apply_to_ast ( ast .Call (
259
273
func = ast .Attribute (
260
274
value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
261
275
attr = "redirect_binary" ,
262
276
ctx = ast .Load (),
263
277
),
264
278
args = [ast .Constant (value = type (node .ops [0 ]).__name__ ), left , right ],
265
279
keywords = [],
266
- )
280
+ ))
267
281
268
282
def visit_Subscript (self , node : ast .Subscript ) -> Any :
283
+ span = Span .from_ast (node )
284
+ assert span is not None
269
285
value = self .visit (node .value )
270
- return ast .Subscript (
286
+ return span . apply_to_ast ( ast .Subscript (
271
287
value = ast .Call (
272
288
func = ast .Attribute (
273
289
value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
@@ -279,11 +295,13 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
279
295
),
280
296
slice = node .slice ,
281
297
ctx = node .ctx ,
282
- )
298
+ ))
283
299
284
300
def visit_Attribute (self , node : ast .Attribute ) -> Any :
301
+ span = Span .from_ast (node )
302
+ assert span is not None
285
303
value = self .visit (node .value )
286
- return ast .Attribute (
304
+ return span . apply_to_ast ( ast .Attribute (
287
305
value = ast .Call (
288
306
func = ast .Attribute (
289
307
value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
@@ -295,9 +313,11 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
295
313
),
296
314
attr = node .attr ,
297
315
ctx = node .ctx ,
298
- )
316
+ ))
299
317
300
318
def visit_If (self , node : ast .If ) -> Any :
319
+ span = Span .from_ast (node )
320
+ assert span is not None
301
321
if_id = self .new_id () + "_if"
302
322
with_item = ast .withitem (
303
323
context_expr = ast .Call (
@@ -361,10 +381,12 @@ def visit_If(self, node: ast.If) -> Any:
361
381
]),
362
382
orelse = [],
363
383
)
364
- with_stmt = ast .With (items = [with_item ], body = [true_branch , false_branch ])
384
+ with_stmt = span . apply_to_ast ( ast .With (items = [with_item ], body = [true_branch , false_branch ]) )
365
385
return with_stmt
366
386
367
387
def visit_Return (self , node : ast .Return ) -> Any :
388
+ span = Span .from_ast (node )
389
+ assert span is not None
368
390
self .return_cnt += 1
369
391
if self .is_tracing :
370
392
if self .return_cnt > 1 :
@@ -380,7 +402,7 @@ def visit_Return(self, node: ast.Return) -> Any:
380
402
tmp = self .visit (node .value )
381
403
assert isinstance (tmp , ast .expr )
382
404
ret_value = tmp
383
- return ast .If (
405
+ return span . apply_to_ast ( ast .If (
384
406
test = ast .Call (
385
407
func = ast .Attribute (
386
408
value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
@@ -407,10 +429,12 @@ def visit_Return(self, node: ast.Return) -> Any:
407
429
)
408
430
],
409
431
),
410
- )
432
+ ))
411
433
412
434
def visit_Break (self , node : ast .Break ) -> Any :
413
- return ast .If (
435
+ span = Span .from_ast (node )
436
+ assert span is not None
437
+ return span .apply_to_ast (ast .If (
414
438
test = ast .Call (
415
439
func = ast .Attribute (
416
440
value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
@@ -437,10 +461,12 @@ def visit_Break(self, node: ast.Break) -> Any:
437
461
)
438
462
],
439
463
),
440
- )
464
+ ))
441
465
442
466
def visit_Continue (self , node : ast .Continue ) -> Any :
443
- return ast .If (
467
+ span = Span .from_ast (node )
468
+ assert span is not None
469
+ return span .apply_to_ast (ast .If (
444
470
test = ast .Call (
445
471
func = ast .Attribute (
446
472
value = ast .Name (id = "__lc_ctx__" , ctx = ast .Load ()),
@@ -467,15 +493,14 @@ def visit_Continue(self, node: ast.Continue) -> Any:
467
493
)
468
494
],
469
495
),
470
- )
496
+ ))
471
497
472
498
473
499
def rewrite_function [F : Callable [..., Any ]](f : F , decorator_name : str ) -> F :
474
500
tree , filename = retrieve_ast_and_filename (f )
475
501
tree = FuncRewriter (decorator_name , filename ).visit (tree )
476
502
ast .fix_missing_locations (tree )
477
- # print(ast.unparse(tree))
478
- code = compile (tree , filename = "<ast>" , mode = "exec" )
503
+ code = compile (tree , filename = filename , mode = "exec" )
479
504
local_dict : dict [Any , Any ] = {}
480
505
exec (code , f .__globals__ , local_dict )
481
506
rewrote_f = local_dict [f .__name__ ]
0 commit comments