Skip to content

Commit abba33d

Browse files
committed
Support control flow inputs in IRBuilder
Since multivalue was standardized, WebAssembly has supported not only multiple results but also an arbitrary number of inputs on control flow structures, but until now Binaryen did not support control flow input. Binaryen IR still has no way to represent control flow input, so lower it away using scratch locals in IRBuilder. Since both the text and binary parsers use IRBuilder, this gives us full support for parsing control flow inputs. The lowering scheme is mostly simple. A local.set writing the control flow inputs to a scratch local is inserted immediately before the control flow structure begins and a local.get retrieving those inputs is inserted inside the control flow structure before the rest of its body. The only complications come from ifs, in which the inputs must be retrieved at the beginning of both arms, and from loops, where branches to the beginning of the loop must be transformed so their values are written to the scratch local along the way. Resolves #6407.
1 parent 52bc45f commit abba33d

12 files changed

+824
-162
lines changed

src/parser/contexts.h

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,11 +1447,7 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> {
14471447

14481448
Result<HeapType> getBlockTypeFromTypeUse(Index pos, HeapType type) {
14491449
assert(type.isSignature());
1450-
if (type.getSignature().params != Type::none) {
1451-
return in.err(pos, "block parameters not yet supported");
1452-
}
1453-
// TODO: Once we support block parameters, return an error here if any of
1454-
// them are named.
1450+
// TODO: Error if block parameters are named
14551451
return type;
14561452
}
14571453

@@ -1822,20 +1818,23 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> {
18221818
HeapType type) {
18231819
// TODO: validate labels?
18241820
// TODO: Move error on input types to here?
1825-
return withLoc(pos,
1826-
irBuilder.makeBlock(label ? *label : Name{},
1827-
type.getSignature().results));
1821+
if (!type.isSignature()) {
1822+
return in.err(pos, "expected function type");
1823+
}
1824+
return withLoc(
1825+
pos, irBuilder.makeBlock(label ? *label : Name{}, type.getSignature()));
18281826
}
18291827

18301828
Result<> makeIf(Index pos,
18311829
const std::vector<Annotation>& annotations,
18321830
std::optional<Name> label,
18331831
HeapType type) {
18341832
// TODO: validate labels?
1835-
// TODO: Move error on input types to here?
1833+
if (!type.isSignature()) {
1834+
return in.err(pos, "expected function type");
1835+
}
18361836
return withLoc(
1837-
pos,
1838-
irBuilder.makeIf(label ? *label : Name{}, type.getSignature().results));
1837+
pos, irBuilder.makeIf(label ? *label : Name{}, type.getSignature()));
18391838
}
18401839

18411840
Result<> visitElse() { return withLoc(irBuilder.visitElse()); }
@@ -1845,21 +1844,23 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> {
18451844
std::optional<Name> label,
18461845
HeapType type) {
18471846
// TODO: validate labels?
1848-
// TODO: Move error on input types to here?
1847+
if (!type.isSignature()) {
1848+
return in.err(pos, "expected function type");
1849+
}
18491850
return withLoc(
1850-
pos,
1851-
irBuilder.makeLoop(label ? *label : Name{}, type.getSignature().results));
1851+
pos, irBuilder.makeLoop(label ? *label : Name{}, type.getSignature()));
18521852
}
18531853

18541854
Result<> makeTry(Index pos,
18551855
const std::vector<Annotation>& annotations,
18561856
std::optional<Name> label,
18571857
HeapType type) {
18581858
// TODO: validate labels?
1859-
// TODO: Move error on input types to here?
1859+
if (!type.isSignature()) {
1860+
return in.err(pos, "expected function type");
1861+
}
18601862
return withLoc(
1861-
pos,
1862-
irBuilder.makeTry(label ? *label : Name{}, type.getSignature().results));
1863+
pos, irBuilder.makeTry(label ? *label : Name{}, type.getSignature()));
18631864
}
18641865

18651866
Result<> makeTryTable(Index pos,
@@ -1875,12 +1876,10 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> {
18751876
labels.push_back(info.label);
18761877
isRefs.push_back(info.isRef);
18771878
}
1878-
return withLoc(pos,
1879-
irBuilder.makeTryTable(label ? *label : Name{},
1880-
type.getSignature().results,
1881-
tags,
1882-
labels,
1883-
isRefs));
1879+
return withLoc(
1880+
pos,
1881+
irBuilder.makeTryTable(
1882+
label ? *label : Name{}, type.getSignature(), tags, labels, isRefs));
18841883
}
18851884

18861885
Result<> visitCatch(Index pos, Name tag) {

src/wasm-binary.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1467,10 +1467,12 @@ class WasmBinaryReader {
14671467

14681468
bool getBasicType(int32_t code, Type& out);
14691469
bool getBasicHeapType(int64_t code, HeapType& out);
1470+
// Get the signature of control flow structure.
1471+
Signature getBlockType();
14701472
// Read a value and get a type for it.
14711473
Type getType();
14721474
// Get a type given the initial S32LEB has already been read, and is provided.
1473-
Type getType(int initial);
1475+
Type getType(int code);
14741476
HeapType getHeapType();
14751477
HeapType getIndexedHeapType();
14761478

src/wasm-ir-builder.h

Lines changed: 67 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,18 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
8080
// the corresponding `makeXYZ` function below instead of `visitXYZStart`, but
8181
// either way must call `visitEnd` and friends at the appropriate times.
8282
Result<> visitFunctionStart(Function* func);
83-
Result<> visitBlockStart(Block* block);
84-
Result<> visitIfStart(If* iff, Name label = {});
83+
Result<> visitBlockStart(Block* block, Type inputType = Type::none);
84+
Result<> visitIfStart(If* iff, Name label = {}, Type inputType = Type::none);
8585
Result<> visitElse();
86-
Result<> visitLoopStart(Loop* iff);
87-
Result<> visitTryStart(Try* tryy, Name label = {});
86+
Result<> visitLoopStart(Loop* iff, Type inputType = Type::none);
87+
Result<>
88+
visitTryStart(Try* tryy, Name label = {}, Type inputType = Type::none);
8889
Result<> visitCatch(Name tag);
8990
Result<> visitCatchAll();
9091
Result<> visitDelegate(Index label);
91-
Result<> visitTryTableStart(TryTable* trytable, Name label = {});
92+
Result<> visitTryTableStart(TryTable* trytable,
93+
Name label = {},
94+
Type inputType = Type::none);
9295
Result<> visitEnd();
9396

9497
// Used to visit break nodes when traversing a single block without its
@@ -113,9 +116,9 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
113116
// nodes. This is generally safer than calling `visit` because the function
114117
// signatures ensure that there are no missing fields.
115118
Result<> makeNop();
116-
Result<> makeBlock(Name label, Type type);
117-
Result<> makeIf(Name label, Type type);
118-
Result<> makeLoop(Name label, Type type);
119+
Result<> makeBlock(Name label, Signature sig);
120+
Result<> makeIf(Name label, Signature sig);
121+
Result<> makeLoop(Name label, Signature sig);
119122
Result<> makeBreak(Index label, bool isConditional);
120123
Result<> makeSwitch(const std::vector<Index>& labels, Index defaultLabel);
121124
// Unlike Builder::makeCall, this assumes the function already exists.
@@ -180,9 +183,9 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
180183
Result<> makeTableFill(Name table);
181184
Result<> makeTableCopy(Name destTable, Name srcTable);
182185
Result<> makeTableInit(Name elem, Name table);
183-
Result<> makeTry(Name label, Type type);
186+
Result<> makeTry(Name label, Signature sig);
184187
Result<> makeTryTable(Name label,
185-
Type type,
188+
Signature sig,
186189
const std::vector<Name>& tags,
187190
const std::vector<Index>& labels,
188191
const std::vector<bool>& isRefs);
@@ -323,13 +326,21 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
323326

324327
// The branch label name for this scope. Always fresh, never shadowed.
325328
Name label;
329+
326330
// For Try/Catch/CatchAll scopes, we need to separately track a label used
327331
// for branches, since the normal label is only used for delegates.
328332
Name branchLabel;
329333

330334
bool labelUsed = false;
331335

336+
// If the control flow scope has an input type, we need to lower it using a
337+
// scratch local because we cannot represent control flow input in the IR.
338+
Type inputType;
339+
Index inputLocal = -1;
340+
341+
// The stack of instructions being built in this scope.
332342
std::vector<Expression*> exprStack;
343+
333344
// Whether we have seen an unreachable instruction and are in
334345
// stack-polymorphic unreachable mode.
335346
bool unreachable = false;
@@ -338,29 +349,39 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
338349
size_t startPos = 0;
339350

340351
ScopeCtx() : scope(NoScope{}) {}
341-
ScopeCtx(Scope scope) : scope(scope) {}
342-
ScopeCtx(Scope scope, Name label, bool labelUsed)
343-
: scope(scope), label(label), labelUsed(labelUsed) {}
352+
ScopeCtx(Scope scope, Type inputType)
353+
: scope(scope), inputType(inputType) {}
354+
ScopeCtx(
355+
Scope scope, Name label, bool labelUsed, Type inputType, Index inputLocal)
356+
: scope(scope), label(label), labelUsed(labelUsed), inputType(inputType),
357+
inputLocal(inputLocal) {}
344358
ScopeCtx(Scope scope, Name label, bool labelUsed, Name branchLabel)
345359
: scope(scope), label(label), branchLabel(branchLabel),
346360
labelUsed(labelUsed) {}
347361

348362
static ScopeCtx makeFunc(Function* func) {
349-
return ScopeCtx(FuncScope{func});
363+
return ScopeCtx(FuncScope{func}, Type::none);
350364
}
351-
static ScopeCtx makeBlock(Block* block) {
352-
return ScopeCtx(BlockScope{block});
365+
static ScopeCtx makeBlock(Block* block, Type inputType) {
366+
return ScopeCtx(BlockScope{block}, inputType);
353367
}
354-
static ScopeCtx makeIf(If* iff, Name originalLabel = {}) {
355-
return ScopeCtx(IfScope{iff, originalLabel});
368+
static ScopeCtx makeIf(If* iff, Name originalLabel, Type inputType) {
369+
return ScopeCtx(IfScope{iff, originalLabel}, inputType);
356370
}
357-
static ScopeCtx
358-
makeElse(If* iff, Name originalLabel, Name label, bool labelUsed) {
359-
return ScopeCtx(ElseScope{iff, originalLabel}, label, labelUsed);
371+
static ScopeCtx makeElse(If* iff,
372+
Name originalLabel,
373+
Name label,
374+
bool labelUsed,
375+
Type inputType,
376+
Index inputLocal) {
377+
return ScopeCtx(
378+
ElseScope{iff, originalLabel}, label, labelUsed, inputType, inputLocal);
360379
}
361-
static ScopeCtx makeLoop(Loop* loop) { return ScopeCtx(LoopScope{loop}); }
362-
static ScopeCtx makeTry(Try* tryy, Name originalLabel = {}) {
363-
return ScopeCtx(TryScope{tryy, originalLabel});
380+
static ScopeCtx makeLoop(Loop* loop, Type inputType) {
381+
return ScopeCtx(LoopScope{loop}, inputType);
382+
}
383+
static ScopeCtx makeTry(Try* tryy, Name originalLabel, Type inputType) {
384+
return ScopeCtx(TryScope{tryy, originalLabel}, inputType);
364385
}
365386
static ScopeCtx makeCatch(Try* tryy,
366387
Name originalLabel,
@@ -378,8 +399,9 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
378399
return ScopeCtx(
379400
CatchAllScope{tryy, originalLabel}, label, labelUsed, branchLabel);
380401
}
381-
static ScopeCtx makeTryTable(TryTable* trytable, Name originalLabel = {}) {
382-
return ScopeCtx(TryTableScope{trytable, originalLabel});
402+
static ScopeCtx
403+
makeTryTable(TryTable* trytable, Name originalLabel, Type inputType) {
404+
return ScopeCtx(TryTableScope{trytable, originalLabel}, inputType);
383405
}
384406

385407
bool isNone() { return std::get_if<NoScope>(&scope); }
@@ -518,6 +540,7 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
518540
}
519541
WASM_UNREACHABLE("unexpected scope kind");
520542
}
543+
bool isDelimiter() { return getElse() || getCatch() || getCatchAll(); }
521544
};
522545

523546
// The stack of block contexts currently being parsed.
@@ -541,7 +564,7 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
541564
Index blockHint = 0;
542565
Index labelHint = 0;
543566

544-
void pushScope(ScopeCtx scope) {
567+
Result<> pushScope(ScopeCtx&& scope) {
545568
if (auto label = scope.getOriginalLabel()) {
546569
// Assign a fresh label to the scope, if necessary.
547570
if (!scope.label) {
@@ -554,7 +577,21 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
554577
scope.startPos = lastBinaryPos;
555578
lastBinaryPos = *binaryPos;
556579
}
557-
scopeStack.push_back(scope);
580+
bool hasInput = scope.inputType != Type::none;
581+
Index inputLocal = scope.inputLocal;
582+
if (hasInput && !scope.isDelimiter()) {
583+
if (inputLocal == Index(-1)) {
584+
auto scratch = addScratchLocal(scope.inputType);
585+
CHECK_ERR(scratch);
586+
inputLocal = scope.inputLocal = *scratch;
587+
}
588+
CHECK_ERR(makeLocalSet(inputLocal));
589+
}
590+
scopeStack.emplace_back(std::move(scope));
591+
if (hasInput) {
592+
CHECK_ERR(makeLocalGet(inputLocal));
593+
}
594+
return Ok{};
558595
}
559596

560597
ScopeCtx& getScope() {
@@ -610,6 +647,8 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
610647
Result<Type> getLabelType(Index label);
611648
Result<Type> getLabelType(Name labelName);
612649

650+
void fixLoopWithInput(Loop* loop, Type inputType, Index scratch);
651+
613652
void dump();
614653
};
615654

src/wasm/wasm-binary.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,30 +2121,30 @@ bool WasmBinaryReader::getBasicHeapType(int64_t code, HeapType& out) {
21212121
}
21222122
}
21232123

2124-
Type WasmBinaryReader::getType(int initial) {
2125-
// Single value types are negative; signature indices are non-negative
2126-
if (initial >= 0) {
2127-
// TODO: Handle block input types properly.
2128-
auto sig = getSignatureByTypeIndex(initial);
2129-
if (sig.params != Type::none) {
2130-
throwError("control flow inputs are not supported yet");
2131-
}
2132-
return sig.results;
2124+
Signature WasmBinaryReader::getBlockType() {
2125+
// Single value types are negative; signature indices are non-negative.
2126+
auto code = getS32LEB();
2127+
if (code >= 0) {
2128+
return getSignatureByTypeIndex(code);
2129+
}
2130+
if (code == BinaryConsts::EncodedType::Empty) {
2131+
return Signature();
21332132
}
2133+
return Signature(Type::none, getType(code));
2134+
}
2135+
2136+
Type WasmBinaryReader::getType(int code) {
21342137
Type type;
2135-
if (getBasicType(initial, type)) {
2138+
if (getBasicType(code, type)) {
21362139
return type;
21372140
}
2138-
switch (initial) {
2139-
// None only used for block signatures. TODO: Separate out?
2140-
case BinaryConsts::EncodedType::Empty:
2141-
return Type::none;
2141+
switch (code) {
21422142
case BinaryConsts::EncodedType::nullable:
21432143
return Type(getHeapType(), Nullable);
21442144
case BinaryConsts::EncodedType::nonnullable:
21452145
return Type(getHeapType(), NonNullable);
21462146
default:
2147-
throwError("invalid wasm type: " + std::to_string(initial));
2147+
throwError("invalid wasm type: " + std::to_string(code));
21482148
}
21492149
WASM_UNREACHABLE("unexpected type");
21502150
}
@@ -2885,11 +2885,11 @@ Result<> WasmBinaryReader::readInst() {
28852885
uint8_t code = getInt8();
28862886
switch (code) {
28872887
case BinaryConsts::Block:
2888-
return builder.makeBlock(Name(), getType());
2888+
return builder.makeBlock(Name(), getBlockType());
28892889
case BinaryConsts::If:
2890-
return builder.makeIf(Name(), getType());
2890+
return builder.makeIf(Name(), getBlockType());
28912891
case BinaryConsts::Loop:
2892-
return builder.makeLoop(Name(), getType());
2892+
return builder.makeLoop(Name(), getBlockType());
28932893
case BinaryConsts::Br:
28942894
return builder.makeBreak(getU32LEB(), false);
28952895
case BinaryConsts::BrIf:
@@ -2974,9 +2974,9 @@ Result<> WasmBinaryReader::readInst() {
29742974
case BinaryConsts::TableSet:
29752975
return builder.makeTableSet(getTableName(getU32LEB()));
29762976
case BinaryConsts::Try:
2977-
return builder.makeTry(Name(), getType());
2977+
return builder.makeTry(Name(), getBlockType());
29782978
case BinaryConsts::TryTable: {
2979-
auto type = getType();
2979+
auto type = getBlockType();
29802980
std::vector<Name> tags;
29812981
std::vector<Index> labels;
29822982
std::vector<bool> isRefs;

0 commit comments

Comments
 (0)