Skip to content

Commit ddc2805

Browse files
committed
need to rethink about this...
1 parent e37ccf1 commit ddc2805

File tree

2 files changed

+105
-75
lines changed

2 files changed

+105
-75
lines changed

luisa_lang/hir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,7 @@ class Assign(Node):
10721072
value: Value
10731073

10741074
def __init__(self, ref: Ref, value: Value, span: Optional[Span] = None) -> None:
1075+
assert not isinstance(value.type, (FunctionType, TypeConstructorType))
10751076
super().__init__(span)
10761077
self.ref = ref
10771078
self.value = value

luisa_lang/parse.py

Lines changed: 104 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -352,48 +352,10 @@ def get_index_type(self, span: Optional[hir.Span], base: hir.Type, index: hir.Va
352352
ty = base.member(hir.DynamicIndex())
353353
return ty
354354

355-
def parse_access_ref(self, expr: ast.Subscript | ast.Attribute) -> hir.Ref:
355+
def parse_access_ref(self, expr: ast.Subscript | ast.Attribute) -> hir.Ref | hir.TypeValue:
356356
span = hir.Span.from_ast(expr)
357357
if isinstance(expr, ast.Subscript):
358358
value = self.parse_ref(expr.value)
359-
index = self.parse_expr(expr.slice)
360-
assert isinstance(value, hir.Ref) and isinstance(index, hir.Value)
361-
index = self.convert_to_value(index, span)
362-
assert value.type
363-
index_ty = self.get_index_type(span, value.type, index)
364-
if index_ty is not None:
365-
return self.cur_bb().append(hir.IndexRef(value, index, type=index_ty, span=span))
366-
else:
367-
# check __getitem__
368-
if (method := value.type.method("__getitem__")) and method:
369-
ret = self.parse_call_impl(
370-
span, method, [value, index])
371-
if isinstance(ret, hir.TemplateMatchingError):
372-
raise hir.TypeInferenceError(
373-
expr, f"error calling __getitem__: {ret.message}")
374-
return self.cur_bb().append(hir.LocalRef(ret))
375-
else:
376-
raise hir.TypeInferenceError(
377-
expr, f"indexing not supported for type {value.type}")
378-
elif isinstance(expr, ast.Attribute):
379-
value = self.parse_ref(expr.value)
380-
assert isinstance(value, hir.Ref)
381-
attr_name = expr.attr
382-
assert value.type
383-
member_ty = value.type.member(attr_name)
384-
if not member_ty:
385-
raise hir.ParsingError(
386-
expr, f"member {attr_name} not found in type {value.type}")
387-
if isinstance(member_ty, hir.FunctionType):
388-
raise hir.ParsingError(
389-
expr, f"method {attr_name} cannot be used as reference")
390-
return self.cur_bb().append(hir.MemberRef(value, attr_name, type=member_ty, span=span))
391-
raise NotImplementedError() # unreachable
392-
393-
def parse_access(self, expr: ast.Subscript | ast.Attribute) -> hir.Value | ComptimeValue:
394-
span = hir.Span.from_ast(expr)
395-
if isinstance(expr, ast.Subscript):
396-
value = self.parse_expr(expr.value)
397359
if isinstance(value, ComptimeValue):
398360
raise hir.ParsingError(
399361
expr, "attempt to access comptime value in DSL code; wrap it in lc.comptime(...) if you intead to use it as a compile-time expression")
@@ -418,13 +380,13 @@ def parse_type_arg(expr: ast.expr) -> hir.Type:
418380
value.type.inner, hir.ParametricType)
419381
return hir.TypeValue(
420382
hir.BoundType(value.type.inner, type_args, value.type.inner.instantiate(type_args)))
421-
422-
assert value.type
423383
index = self.parse_expr(expr.slice)
384+
assert isinstance(value, hir.Ref) and isinstance(index, hir.Value)
424385
index = self.convert_to_value(index, span)
386+
assert value.type
425387
index_ty = self.get_index_type(span, value.type, index)
426388
if index_ty is not None:
427-
return self.cur_bb().append(hir.Index(value, index, type=index_ty, span=span))
389+
return self.cur_bb().append(hir.IndexRef(value, index, type=index_ty, span=span))
428390
else:
429391
# check __getitem__
430392
if (method := value.type.method("__getitem__")) and method:
@@ -433,28 +395,90 @@ def parse_type_arg(expr: ast.expr) -> hir.Type:
433395
if isinstance(ret, hir.TemplateMatchingError):
434396
raise hir.TypeInferenceError(
435397
expr, f"error calling __getitem__: {ret.message}")
436-
return ret
398+
return self.cur_bb().append(hir.LocalRef(ret))
437399
else:
438400
raise hir.TypeInferenceError(
439401
expr, f"indexing not supported for type {value.type}")
440402
elif isinstance(expr, ast.Attribute):
441-
def do() -> ComptimeValue | hir.Value:
442-
value = self.parse_ref(expr.value)
443-
attr_name = expr.attr
444-
if isinstance(value, ComptimeValue):
445-
return ComptimeValue(getattr(value.value, attr_name), None)
446-
assert value.type
447-
member_ty = value.type.member(attr_name)
448-
if not member_ty:
449-
raise hir.ParsingError(
450-
expr, f"member {attr_name} not found in type {value.type}")
451-
if isinstance(member_ty, hir.FunctionType):
452-
if not isinstance(value, hir.TypeValue):
453-
member_ty.bound_object = value
454-
return self.cur_bb().append(hir.Member(self.convert_to_value(value, span), attr_name, type=member_ty, span=span))
455-
return do()
403+
value = self.parse_ref(expr.value)
404+
assert isinstance(value, hir.Ref)
405+
attr_name = expr.attr
406+
assert value.type
407+
member_ty = value.type.member(attr_name)
408+
if not member_ty:
409+
raise hir.ParsingError(
410+
expr, f"member {attr_name} not found in type {value.type}")
411+
if isinstance(member_ty, hir.FunctionType):
412+
if not isinstance(value, hir.TypeValue):
413+
member_ty.bound_object = value
414+
return self.cur_bb().append(hir.MemberRef(value, attr_name, type=member_ty, span=span))
456415
raise NotImplementedError() # unreachable
457416

417+
# def parse_access(self, expr: ast.Subscript | ast.Attribute) -> hir.Value | ComptimeValue:
418+
# span = hir.Span.from_ast(expr)
419+
# if isinstance(expr, ast.Subscript):
420+
# value = self.parse_expr(expr.value)
421+
# if isinstance(value, ComptimeValue):
422+
# raise hir.ParsingError(
423+
# expr, "attempt to access comptime value in DSL code; wrap it in lc.comptime(...) if you intead to use it as a compile-time expression")
424+
# if isinstance(value, hir.TypeValue):
425+
# type_args: List[hir.Type] = []
426+
427+
# def parse_type_arg(expr: ast.expr) -> hir.Type:
428+
# type_annotation = self.eval_expr(expr)
429+
# type_hint = classinfo.parse_type_hint(type_annotation)
430+
# ty = self.parse_type(type_hint)
431+
# assert ty
432+
# return ty
433+
434+
# match expr.slice:
435+
# case ast.Tuple():
436+
# for e in expr.slice.elts:
437+
# type_args.append(parse_type_arg(e))
438+
# case _:
439+
# type_args.append(parse_type_arg(expr.slice))
440+
# # print(f"Type args: {type_args}")
441+
# assert isinstance(value.type, hir.TypeConstructorType) and isinstance(
442+
# value.type.inner, hir.ParametricType)
443+
# return hir.TypeValue(
444+
# hir.BoundType(value.type.inner, type_args, value.type.inner.instantiate(type_args)))
445+
446+
# assert value.type
447+
# index = self.parse_expr(expr.slice)
448+
# index = self.convert_to_value(index, span)
449+
# index_ty = self.get_index_type(span, value.type, index)
450+
# if index_ty is not None:
451+
# return self.cur_bb().append(hir.Index(value, index, type=index_ty, span=span))
452+
# else:
453+
# # check __getitem__
454+
# if (method := value.type.method("__getitem__")) and method:
455+
# ret = self.parse_call_impl(
456+
# span, method, [value, index])
457+
# if isinstance(ret, hir.TemplateMatchingError):
458+
# raise hir.TypeInferenceError(
459+
# expr, f"error calling __getitem__: {ret.message}")
460+
# return ret
461+
# else:
462+
# raise hir.TypeInferenceError(
463+
# expr, f"indexing not supported for type {value.type}")
464+
# elif isinstance(expr, ast.Attribute):
465+
# def do() -> ComptimeValue | hir.Value:
466+
# value = self.parse_ref(expr.value)
467+
# attr_name = expr.attr
468+
# if isinstance(value, ComptimeValue):
469+
# return ComptimeValue(getattr(value.value, attr_name), None)
470+
# assert value.type
471+
# member_ty = value.type.member(attr_name)
472+
# if not member_ty:
473+
# raise hir.ParsingError(
474+
# expr, f"member {attr_name} not found in type {value.type}")
475+
# if isinstance(member_ty, hir.FunctionType):
476+
# if not isinstance(value, hir.TypeValue):
477+
# member_ty.bound_object = value
478+
# return self.cur_bb().append(hir.Member(self.convert_to_value(value, span), attr_name, type=member_ty, span=span))
479+
# return do()
480+
# raise NotImplementedError() # unreachable
481+
458482
def parse_call_impl(self, span: hir.Span | None, f: hir.FunctionLike | hir.FunctionTemplate, args: List[hir.Value | hir.Ref]) -> hir.Value | hir.TemplateMatchingError:
459483
if isinstance(f, hir.FunctionTemplate):
460484
if f.is_generic:
@@ -620,11 +644,10 @@ def make_int(i: int) -> hir.Value:
620644
raise RuntimeError(f"Unsupported special function {f}")
621645

622646
def parse_call(self, expr: ast.Call) -> hir.Value | ComptimeValue:
623-
func = self.parse_expr(expr.func) # TODO: this should be a parse_ref
647+
func: hir.Ref | ComptimeValue | hir.TypeValue | hir.Value = self.parse_ref(expr.func) # TODO: this should be a parse_ref
624648
span = hir.Span.from_ast(expr)
625-
if isinstance(func, hir.Ref):
626-
raise hir.ParsingError(expr, f"function expected")
627-
elif isinstance(func, ComptimeValue):
649+
650+
if isinstance(func, ComptimeValue):
628651
if func.value in SPECIAL_FUNCTIONS:
629652
return self.handle_special_functions(func.value, expr)
630653
func = self.try_convert_comptime_value(
@@ -638,11 +661,11 @@ def collect_args() -> List[hir.Value | hir.Ref]:
638661
arg, hir.Span.from_ast(expr.args[i]))
639662
return cast(List[hir.Value | hir.Ref], args)
640663

641-
if isinstance(func.type, hir.TypeConstructorType):
664+
if isinstance(func, hir.TypeValue):
642665
# TypeConstructorType is unique for each type
643666
# so if any value has this type, it must be referring to the same underlying type
644667
# even if it comes from a very complex expression, it's still fine
645-
cls = func.type.inner
668+
cls = func.inner_type()
646669
assert cls
647670
if isinstance(cls, hir.ParametricType):
648671
raise hir.ParsingError(
@@ -693,7 +716,6 @@ def collect_args() -> List[hir.Value | hir.Ref]:
693716
# raise hir.ParsingError(expr, f"function expected but got {func}")
694717
# else:
695718
# func_like = func.func
696-
697719

698720
# def parse_compare(self, expr: ast.Compare) -> hir.Value | ComptimeValue:
699721
# cmpop_to_str: Dict[type, str] = {
@@ -786,32 +808,37 @@ def infer_binop(name: str, rname: str) -> hir.Value:
786808
raise e from e
787809
return infer_binop(ops[0], ops[1])
788810

789-
def parse_ref(self, expr: ast.expr, new_var_hint: NewVarHint = False) -> hir.Ref | ComptimeValue:
811+
def parse_ref(self, expr: ast.expr, new_var_hint: NewVarHint = False) -> hir.Ref | ComptimeValue | hir.TypeValue:
790812
match expr:
791813
case ast.Name():
792814
ret = self.parse_name(expr, new_var_hint)
793815
if isinstance(ret, (hir.Value)):
794-
raise hir.ParsingError(
795-
expr, f"{type(ret)} cannot be used as reference")
816+
if isinstance(ret.type, hir.TypeConstructorType):
817+
assert isinstance(ret, hir.TypeValue)
818+
return ret
819+
# raise hir.ParsingError(
820+
# expr, f"{type(ret)} cannot be used as reference")
821+
assert ret.type
822+
tmp = self.cur_bb().append(hir.Alloca(ret.type, hir.Span.from_ast(expr)))
823+
self.cur_bb().append(hir.Assign(tmp, ret))
824+
return tmp
796825
return ret
797826
case ast.Subscript() | ast.Attribute():
798827
return self.parse_access_ref(expr)
799828
case _:
800829
raise hir.ParsingError(
801830
expr, f"expression {ast.dump(expr)} cannot be parsed as reference")
802831

803-
# def parse_assignment_targets(self, targets: List[ast.expr], new_var_hint: NewVarHint) -> List[hir.Ref]:
804-
# return [self.parse_ref(t, new_var_hint) for t in targets]
805-
806-
# def assign(self, targets: List[hir.Ref], values: hir.Value | ComptimeValue) -> None:
807-
# pass
808-
809832
def parse_multi_assignment(self,
810833
targets: List[ast.expr],
811834
anno_ty_fn: List[Optional[Callable[..., hir.Type | None]]],
812835
values: hir.Value | ComptimeValue) -> None:
813836
if isinstance(values, ComptimeValue):
814837
parsed_targets = [self.parse_ref(t, 'comptime') for t in targets]
838+
for i in range(len(parsed_targets)):
839+
if isinstance(parsed_targets[i], hir.TypeValue):
840+
raise hir.ParsingError(
841+
targets[i], "types cannot be reassigned")
815842

816843
def do_assign(target: hir.Ref | ComptimeValue, value: ComptimeValue, i: int) -> None:
817844
span = hir.Span.from_ast(targets[i])
@@ -825,10 +852,12 @@ def do_assign(target: hir.Ref | ComptimeValue, value: ComptimeValue, i: int) ->
825852
raise hir.ParsingError(
826853
targets[0], f"expected {len(parsed_targets)} values to unpack, got {len(values.value)}")
827854
for i, t in enumerate(parsed_targets):
855+
assert isinstance(t, (hir.Ref, ComptimeValue))
828856
do_assign(t, values.value[i],
829857
i)
830858
else:
831859
t = parsed_targets[0]
860+
assert isinstance(t, (hir.Ref, ComptimeValue))
832861
do_assign(t, values, 0)
833862
else:
834863
parsed_targets = [self.parse_ref(t, 'dsl') for t in targets]
@@ -939,7 +968,7 @@ def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue:
939968
ret = self.convert_to_value(ret, hir.Span.from_ast(expr))
940969
return ret
941970
case ast.Subscript() | ast.Attribute():
942-
return self.parse_access(expr)
971+
return self.convert_to_value(self.parse_access_ref(expr), hir.Span.from_ast(expr))
943972
case ast.BinOp() | ast.Compare():
944973
return self.parse_binop(expr)
945974
case ast.UnaryOp():

0 commit comments

Comments
 (0)