Skip to content

Commit 9bce61e

Browse files
committed
fixed source location in stacktrace
1 parent 39ff76d commit 9bce61e

File tree

3 files changed

+65
-26
lines changed

3 files changed

+65
-26
lines changed

luisa_lang/ast_rewrite.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from ast import NodeTransformer
33
import copy
44
from typing import Callable, Any, List, Set, cast
5-
from luisa_lang.utils import checked_cast, retrieve_ast_and_filename, NestedHashMap
5+
from luisa_lang.utils import Span, checked_cast, retrieve_ast_and_filename, NestedHashMap
66

77
"""
88
Rewrite rules:
@@ -163,17 +163,21 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
163163
return node
164164

165165
def visit_Name(self, node: ast.Name) -> Any:
166+
span = Span.from_ast(node)
167+
assert span is not None
166168
# rewrite to __lc_ctx__.name
167-
return ast.Subscript(
169+
return span.apply_to_ast(ast.Subscript(
168170
value=ast.Name(id="__lc_ctx__", ctx=ast.Load()),
169171
slice=ast.Constant(value=node.id),
170172
ctx=node.ctx,
171-
)
173+
))
172174

173175
def visit_Assign(self, node: ast.Assign) -> Any:
174176
return self.generic_visit(node)
175177

176178
def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
179+
span = Span.from_ast(node)
180+
assert span is not None
177181
target = checked_cast(ast.expr, self.visit(node.target))
178182
assert isinstance(target, (ast.Name, ast.Subscript, ast.Attribute))
179183
target.ctx = ast.Load()
@@ -193,81 +197,93 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
193197
target = copy.deepcopy(target)
194198
target.ctx = ast.Store()
195199
assign = ast.Assign(targets=[target], value=self.visit(node.value))
200+
span.apply_to_ast(anno)
201+
span.apply_to_ast(assign)
196202
return [anno, assign]
197203

198204
def visit_Call(self, node: ast.Call) -> Any:
205+
span = Span.from_ast(node)
206+
assert span is not None
199207
# first check if it is of form `__intrinsic__(...)`
200208
if isinstance(node.func, ast.Name):
201209
if node.func.id in NO_REWRITE_FUNCTIONS:
202210
return node
203211
if node.func.id == "__intrinsic__" or node.func.id == "__intrinsic_checked__":
204212
# rewrite to __lc_ctx__.intrinsic(...)
205-
return ast.Call(
213+
return span.apply_to_ast(ast.Call(
206214
func=ast.Attribute(
207215
value=ast.Name(id="__lc_ctx__", ctx=ast.Load()),
208216
attr=node.func.id[2:-2],
209217
ctx=ast.Load(),
210218
),
211219
args=[self.visit(arg) for arg in node.args],
212220
keywords=[self.visit(kw) for kw in node.keywords],
213-
)
221+
))
214222
# rewrite to __lc_ctx__.redirect_call(func, args...)
215223
func = self.visit(node.func)
216224
args = [self.visit(arg) for arg in node.args]
217225
keywords = [self.visit(kw) for kw in node.keywords]
218-
return ast.Call(
226+
return span.apply_to_ast(ast.Call(
219227
func=ast.Attribute(
220228
value=ast.Name(id="__lc_ctx__", ctx=ast.Load()),
221229
attr="redirect_call",
222230
ctx=ast.Load(),
223231
),
224232
args=[func] + args,
225233
keywords=keywords,
226-
)
234+
))
227235

228236
def visit_BinOp(self, node: ast.BinOp) -> Any:
237+
span = Span.from_ast(node)
238+
assert span is not None
229239
lhs = self.visit(node.left)
230240
rhs = self.visit(node.right)
231-
return ast.Call(
241+
return span.apply_to_ast(ast.Call(
232242
func=ast.Attribute(
233243
value=ast.Name(id="__lc_ctx__", ctx=ast.Load()),
234244
attr="redirect_binary",
235245
ctx=ast.Load(),
236246
),
237247
args=[ast.Constant(value=type(node.op).__name__), lhs, rhs],
238248
keywords=[],
239-
)
249+
))
240250

241251
def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
252+
span = Span.from_ast(node)
253+
assert span is not None
242254
operand = self.visit(node.operand)
243-
return ast.Call(
255+
return span.apply_to_ast(ast.Call(
244256
func=ast.Attribute(
245257
value=ast.Name(id="__lc_ctx__", ctx=ast.Load()),
246258
attr="redirect_unary",
247259
ctx=ast.Load(),
248260
),
249261
args=[ast.Constant(value=type(node.op).__name__), operand],
250262
keywords=[],
251-
)
263+
))
252264

253265
def visit_Compare(self, node: ast.Compare) -> Any:
266+
span = Span.from_ast(node)
267+
assert span is not None
254268
if len(node.ops) != 1 or len(node.comparators) != 1:
255269
raise NotImplementedError("Only single comparison is supported")
256270
left = self.visit(node.left)
257271
right = self.visit(node.comparators[0])
258-
return ast.Call(
272+
return span.apply_to_ast(ast.Call(
259273
func=ast.Attribute(
260274
value=ast.Name(id="__lc_ctx__", ctx=ast.Load()),
261275
attr="redirect_binary",
262276
ctx=ast.Load(),
263277
),
264278
args=[ast.Constant(value=type(node.ops[0]).__name__), left, right],
265279
keywords=[],
266-
)
280+
))
267281

268282
def visit_Subscript(self, node: ast.Subscript) -> Any:
283+
span = Span.from_ast(node)
284+
assert span is not None
269285
value = self.visit(node.value)
270-
return ast.Subscript(
286+
return span.apply_to_ast(ast.Subscript(
271287
value=ast.Call(
272288
func=ast.Attribute(
273289
value=ast.Name(id="__lc_ctx__", ctx=ast.Load()),
@@ -279,11 +295,13 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
279295
),
280296
slice=node.slice,
281297
ctx=node.ctx,
282-
)
298+
))
283299

284300
def visit_Attribute(self, node: ast.Attribute) -> Any:
301+
span = Span.from_ast(node)
302+
assert span is not None
285303
value = self.visit(node.value)
286-
return ast.Attribute(
304+
return span.apply_to_ast(ast.Attribute(
287305
value=ast.Call(
288306
func=ast.Attribute(
289307
value=ast.Name(id="__lc_ctx__", ctx=ast.Load()),
@@ -295,9 +313,11 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
295313
),
296314
attr=node.attr,
297315
ctx=node.ctx,
298-
)
316+
))
299317

300318
def visit_If(self, node: ast.If) -> Any:
319+
span = Span.from_ast(node)
320+
assert span is not None
301321
if_id = self.new_id() + "_if"
302322
with_item = ast.withitem(
303323
context_expr=ast.Call(
@@ -361,10 +381,12 @@ def visit_If(self, node: ast.If) -> Any:
361381
]),
362382
orelse=[],
363383
)
364-
with_stmt = ast.With(items=[with_item], body=[true_branch, false_branch])
384+
with_stmt = span.apply_to_ast(ast.With(items=[with_item], body=[true_branch, false_branch]))
365385
return with_stmt
366386

367387
def visit_Return(self, node: ast.Return) -> Any:
388+
span = Span.from_ast(node)
389+
assert span is not None
368390
self.return_cnt += 1
369391
if self.is_tracing:
370392
if self.return_cnt > 1:
@@ -380,7 +402,7 @@ def visit_Return(self, node: ast.Return) -> Any:
380402
tmp = self.visit(node.value)
381403
assert isinstance(tmp, ast.expr)
382404
ret_value = tmp
383-
return ast.If(
405+
return span.apply_to_ast(ast.If(
384406
test=ast.Call(
385407
func=ast.Attribute(
386408
value=ast.Name(id="__lc_ctx__", ctx=ast.Load()),
@@ -407,10 +429,12 @@ def visit_Return(self, node: ast.Return) -> Any:
407429
)
408430
],
409431
),
410-
)
432+
))
411433

412434
def visit_Break(self, node: ast.Break) -> Any:
413-
return ast.If(
435+
span = Span.from_ast(node)
436+
assert span is not None
437+
return span.apply_to_ast(ast.If(
414438
test=ast.Call(
415439
func=ast.Attribute(
416440
value=ast.Name(id="__lc_ctx__", ctx=ast.Load()),
@@ -437,10 +461,12 @@ def visit_Break(self, node: ast.Break) -> Any:
437461
)
438462
],
439463
),
440-
)
464+
))
441465

442466
def visit_Continue(self, node: ast.Continue) -> Any:
443-
return ast.If(
467+
span = Span.from_ast(node)
468+
assert span is not None
469+
return span.apply_to_ast(ast.If(
444470
test=ast.Call(
445471
func=ast.Attribute(
446472
value=ast.Name(id="__lc_ctx__", ctx=ast.Load()),
@@ -467,15 +493,14 @@ def visit_Continue(self, node: ast.Continue) -> Any:
467493
)
468494
],
469495
),
470-
)
496+
))
471497

472498

473499
def rewrite_function[F: Callable[..., Any]](f: F, decorator_name: str) -> F:
474500
tree, filename = retrieve_ast_and_filename(f)
475501
tree = FuncRewriter(decorator_name, filename).visit(tree)
476502
ast.fix_missing_locations(tree)
477-
# print(ast.unparse(tree))
478-
code = compile(tree, filename="<ast>", mode="exec")
503+
code = compile(tree, filename=filename, mode="exec")
479504
local_dict: dict[Any, Any] = {}
480505
exec(code, f.__globals__, local_dict)
481506
rewrote_f = local_dict[f.__name__]

luisa_lang/lang_runtime.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,9 @@ def create_intrinsic_node[T: JitVar](
533533

534534

535535
def __escape__(x: Any) -> Any:
536+
"""
537+
A marker used by the compiler frontend to prevent an expression from being rewritten.
538+
"""
536539
return x
537540

538541

luisa_lang/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,17 @@ def from_ast(ast: ast.AST) -> Optional["Span"]:
120120
end=(getattr(ast, "end_lineno", 0), getattr(ast, "end_col_offset", 0)),
121121
)
122122

123+
def apply_to_ast(self, ast: ast.AST) -> ast.AST:
124+
"""
125+
Apply the span to the given AST node.
126+
"""
127+
setattr(ast, "lineno", self.start[0])
128+
setattr(ast, "col_offset", self.start[1])
129+
setattr(ast, "end_lineno", self.end[0])
130+
setattr(ast, "end_col_offset", self.end[1])
131+
if self.file is not None:
132+
setattr(ast, "source_file", self.file)
133+
return ast
123134

124135
def print_yellow(message: str) -> None:
125136
print(f"\033[33m{message}\033[0m")

0 commit comments

Comments
 (0)