Skip to content

Commit fd66c6f

Browse files
committed
update readme
1 parent dfbdf3b commit fd66c6f

File tree

1 file changed

+72
-16
lines changed

1 file changed

+72
-16
lines changed

README.md

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ A new Python DSL frontend for LuisaCompute. Will be integrated into LuisaCompute
77
- [Basics](#basic-syntax)
88
- [Difference from Python](#difference-from-python)
99
- [Types](#types)
10-
- [Value & Reference Semantics](#value--reference-semantics)
10+
- [Value Semantics](#value-semantics)
11+
- [Local References](#local-references)
1112
- [Functions](#functions)
1213
- [User-defined Structs](#user-defined-structs)
1314
- [Control Flow](#control-flow)
@@ -47,7 +48,7 @@ def add(a: lc.float, b: lc.float) -> lc.float:
4748
```
4849

4950

50-
### Value & Reference Semantics
51+
### Value Semantics
5152
Variables have value semantics by default. This means that when you assign a variable to another, a copy is made.
5253
```python
5354
a = lc.float3(1.0, 2.0, 3.0)
@@ -56,36 +57,62 @@ a.x = 2.0
5657
lc.print(f'{a.x} {b.x}') # prints 2.0 1.0
5758
```
5859

59-
You can use `byref` to indicate that a variable is passed as a *local reference*. Assigning to an `byref` variable will update the original variable.
60+
#### Local References
61+
There is a logical reference type `lc.Ref[T]` that can be used to pass a value by reference, similar to `inout` in GLSL/HLSL.
6062
```python
61-
@luisa.func(a=byref, b=byref)
62-
def swap(a: int, b: int):
63+
@luisa.func
64+
def swap(a: lc.Ref[lc.float], b: lc.Ref[lc.float]):
6365
a, b = b, a
6466

6567
a = lc.float3(1.0, 2.0, 3.0)
6668
b = lc.float3(4.0, 5.0, 6.0)
6769
swap(a.x, b.x)
6870
lc.print(f'{a.x} {b.x}') # prints 4.0 1.0
6971
```
72+
However, `lc.Ref[T]` is more powerful than `inout` in GLSL/HLSL. You can even return a reference from a function and use it later.
7073

71-
When overloading subscript operator or attribute access, you actually return a local reference to the object.
74+
```python
75+
@lc.func
76+
def get_ref(a: lc.float3) -> lc.Ref[lc.float]:
77+
return a.x
7278

73-
#### Local References
74-
Local references are like pointers in C++. However, they cannot escape the expression boundary. This means that you cannot store a local reference in a variable and use it later. While you can return a local reference from a function, it must be returned from a uniform path. That is you cannot return different local references based on a condition.
79+
a = lc.float3(1.0, 2.0, 3.0)
80+
b = byref(get_ref(a)) # byref is necessary to indicate that the argument is passed by reference
81+
b = 2.0
82+
lc.print(f'{a.x} {b}') # prints 2.0 2.0
83+
```
84+
85+
**Important**: `lc.Ref[T]` is not a true reference type nor a pointer. It is a logical reference that is resolved at compile time. This means that you cannot store a `lc.Ref[T]` in an aggregate type, such as an array or a struct. If you want to return a reference from a function, the function must be inlineable. You also cannot define a local reference inside non-uniform control flow such as `if` or `for` statements. See the following example for the semantics of local references.
86+
```python
87+
a: lc.Ref[T] = byref(some_ref_func()) # a is bound to the reference returned by some_ref_func()
88+
if cond():
89+
a = another_ref_func() # does not bound `a` to a new reference, but changes the value of the reference
90+
b: lc.Ref[T] = another_ref_func() # error, cannot define a local reference inside non-uniform control flow
91+
# to workaround the above issue, you should define a new scope
92+
@lc.block
93+
def inner():
94+
b: lc.Ref[T] = another_ref_func() # this is fine
95+
# do something with b
96+
inner()
97+
```
98+
Further more, when matching template arguments, matching `lc.Ref[T]` to a template argument `U` would result in `U` being `T` instead of `lc.Ref[T]`.
99+
To force `U` to be `lc.Ref[T]`, you can use `lc.Ref[U]` as the template argument.
75100

76101

102+
Certain special methods must return a local reference. For example, `__getitem__` and `__getattr__` must return a local reference.
103+
77104
```python
78105
@lc.struct
79106
class InfiniteArray:
80-
def __getitem__(self, index: int) -> int:
107+
def __getitem__(self, index: int) -> lc.Ref[int]:
81108
return self.data[index] # returns a local reference
82109

83110
# this method will be ignored by the compiler. but you can still put it here for linting
84111
def __setitem__(self, index: int, value: int):
85112
pass
86113

87114
# Not allowed, non-uniform return
88-
def __getitem__(self, index: int) -> int:
115+
def __getitem__(self, index: int) -> lc.Ref[int]:
89116
if index == 0:
90117
return self.data[0]
91118
else:
@@ -94,9 +121,6 @@ class InfiniteArray:
94121
```
95122

96123

97-
98-
99-
100124
### User-defined Structs
101125
```python
102126
@lc.struct
@@ -105,6 +129,39 @@ class Sphere:
105129
radius: lc.float
106130
```
107131

132+
133+
### Control Flow
134+
```python
135+
# the following control flow constructs are supported
136+
if cond:
137+
pass
138+
elif cond:
139+
pass
140+
else:
141+
pass
142+
143+
while cond:
144+
pass
145+
146+
for i in lc.range(10):
147+
pass
148+
```
149+
Additionally, we provide a `lc.block` decorator that can be used to define a block of code that can be inlined into other functions. This is useful for defining shadowing variables or local references.
150+
151+
```python
152+
a = 1
153+
b = 2
154+
@lc.block
155+
def inner():
156+
nonlocal b
157+
a = 2
158+
b = 3
159+
inner()
160+
lc.print(a) # prints 1
161+
lc.print(b) # prints 3
162+
163+
```
164+
108165
### Define DSL Operation for Non-DSL Types
109166
Sometimes we want to use a non-DSL type in our DSL code. Such type could be imported from a third-party library or a built-in Python type. As long as we know the object layout, we can define the DSL operation for it by first defining a proxy struct that mirrors the object layout, and then define the operation for the proxy struct.
110167

@@ -176,9 +233,8 @@ def call_n_times(f: F):
176233
# or
177234
lc.embed_code('apply_func(f, i)')
178235

179-
# Hint a parameter is constexpr
180-
@lc.func(n=lc.comptime) # without this, n will be treated as a runtime variable and result in an error
181-
def pow(x: lc.float, n: int) -> lc.float:
236+
@lc.func
237+
def pow(x: lc.float, n: lc.Comptime[int]) -> lc.float:
182238
p = 1.0
183239
with lc.comptime():
184240
for _ in range(n):

0 commit comments

Comments
 (0)