@@ -11,7 +11,10 @@ use qsc_data_structures::span::Span;
11
11
use qsc_frontend:: compile:: CompileUnit ;
12
12
use qsc_hir:: {
13
13
assigner:: Assigner ,
14
- hir:: { Block , CallableDecl , Expr , ExprKind , NodeId , Res , Stmt , StmtKind , Ty } ,
14
+ hir:: {
15
+ Block , CallableDecl , Expr , ExprKind , Ident , Mutability , NodeId , Pat , PatKind , Res , Stmt ,
16
+ StmtKind , Ty ,
17
+ } ,
15
18
mut_visit:: { self , MutVisitor } ,
16
19
visit:: { self , Visitor } ,
17
20
} ;
@@ -31,6 +34,9 @@ pub enum Error {
31
34
#[ error( "variable cannot be assigned in apply-block since it is used in within-block" ) ]
32
35
#[ diagnostic( help( "updating mutable variables in the apply-block that are used in the within-block can violate logic reversibility" ) ) ]
33
36
ApplyAssign ( #[ label] Span ) ,
37
+
38
+ #[ error( "return expressions are not allowed in apply-blocks" ) ]
39
+ ReturnForbidden ( #[ label] Span ) ,
34
40
}
35
41
36
42
/// Generates adjoint inverted blocks for within-blocks across all conjugate expressions,
@@ -72,8 +78,6 @@ struct ConjugateElim<'a> {
72
78
73
79
impl < ' a > MutVisitor for ConjugateElim < ' a > {
74
80
fn visit_expr ( & mut self , expr : & mut Expr ) {
75
- mut_visit:: walk_expr ( self , expr) ;
76
-
77
81
match take ( & mut expr. kind ) {
78
82
ExprKind :: Conjugate ( within, apply) => {
79
83
let mut usage = Usage {
@@ -87,6 +91,10 @@ impl<'a> MutVisitor for ConjugateElim<'a> {
87
91
assign_check. visit_block ( & apply) ;
88
92
self . errors . extend ( assign_check. errors ) ;
89
93
94
+ let mut return_check = ReturnCheck { errors : Vec :: new ( ) } ;
95
+ return_check. visit_block ( & apply) ;
96
+ self . errors . extend ( return_check. errors ) ;
97
+
90
98
let mut adj_within = within. clone ( ) ;
91
99
if let Err ( invert_errors) = adj_invert_block ( self . assigner , & mut adj_within) {
92
100
self . errors . extend (
@@ -102,28 +110,43 @@ impl<'a> MutVisitor for ConjugateElim<'a> {
102
110
self . errors
103
111
. extend ( distrib. errors . into_iter ( ) . map ( Error :: AdjGen ) ) ;
104
112
113
+ let ( bind_id, apply_as_bind) =
114
+ block_as_binding ( apply, expr. ty . clone ( ) , self . assigner ) ;
115
+
105
116
let new_block = Block {
106
117
id : NodeId :: default ( ) ,
107
118
span : Span :: default ( ) ,
108
- ty : Ty :: UNIT ,
119
+ ty : expr . ty . clone ( ) ,
109
120
stmts : vec ! [
110
121
block_as_stmt( within) ,
111
- block_as_stmt ( apply ) ,
122
+ apply_as_bind ,
112
123
block_as_stmt( adj_within) ,
124
+ Stmt {
125
+ id: NodeId :: default ( ) ,
126
+ span: Span :: default ( ) ,
127
+ kind: StmtKind :: Expr ( Expr {
128
+ id: NodeId :: default ( ) ,
129
+ span: Span :: default ( ) ,
130
+ ty: expr. ty. clone( ) ,
131
+ kind: ExprKind :: Var ( Res :: Local ( bind_id) ) ,
132
+ } ) ,
133
+ } ,
113
134
] ,
114
135
} ;
115
- * expr = block_as_expr ( new_block) ;
136
+ * expr = block_as_expr ( new_block, expr . ty . clone ( ) ) ;
116
137
}
117
138
kind => expr. kind = kind,
118
139
}
140
+
141
+ mut_visit:: walk_expr ( self , expr) ;
119
142
}
120
143
}
121
144
122
- fn block_as_expr ( block : Block ) -> Expr {
145
+ fn block_as_expr ( block : Block , ty : Ty ) -> Expr {
123
146
Expr {
124
147
id : NodeId :: default ( ) ,
125
148
span : Span :: default ( ) ,
126
- ty : Ty :: UNIT ,
149
+ ty,
127
150
kind : ExprKind :: Block ( block) ,
128
151
}
129
152
}
@@ -132,10 +155,35 @@ fn block_as_stmt(block: Block) -> Stmt {
132
155
Stmt {
133
156
id : NodeId :: default ( ) ,
134
157
span : Span :: default ( ) ,
135
- kind : StmtKind :: Expr ( block_as_expr ( block) ) ,
158
+ kind : StmtKind :: Expr ( block_as_expr ( block, Ty :: UNIT ) ) ,
136
159
}
137
160
}
138
161
162
+ fn block_as_binding ( block : Block , ty : Ty , assigner : & mut Assigner ) -> ( NodeId , Stmt ) {
163
+ let bind_id = assigner. next_id ( ) ;
164
+ (
165
+ bind_id,
166
+ Stmt {
167
+ id : assigner. next_id ( ) ,
168
+ span : Span :: default ( ) ,
169
+ kind : StmtKind :: Local (
170
+ Mutability :: Immutable ,
171
+ Pat {
172
+ id : assigner. next_id ( ) ,
173
+ span : Span :: default ( ) ,
174
+ ty : ty. clone ( ) ,
175
+ kind : PatKind :: Bind ( Ident {
176
+ id : bind_id,
177
+ span : Span :: default ( ) ,
178
+ name : "apply_res" . into ( ) ,
179
+ } ) ,
180
+ } ,
181
+ block_as_expr ( block, ty) ,
182
+ ) ,
183
+ } ,
184
+ )
185
+ }
186
+
139
187
struct Usage {
140
188
used : HashSet < NodeId > ,
141
189
}
@@ -186,3 +234,15 @@ impl AssignmentCheck {
186
234
}
187
235
}
188
236
}
237
+
238
+ struct ReturnCheck {
239
+ errors : Vec < Error > ,
240
+ }
241
+
242
+ impl < ' a > Visitor < ' a > for ReturnCheck {
243
+ fn visit_expr ( & mut self , expr : & ' a Expr ) {
244
+ if matches ! ( & expr. kind, ExprKind :: Return ( ..) ) {
245
+ self . errors . push ( Error :: ReturnForbidden ( expr. span ) ) ;
246
+ }
247
+ }
248
+ }
0 commit comments