Skip to content

Commit

Permalink
optimize away trivial blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
kripken committed Oct 29, 2015
1 parent 8874a52 commit 665b5a7
Show file tree
Hide file tree
Showing 5 changed files with 13,481 additions and 13,805 deletions.
1 change: 1 addition & 0 deletions check.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
actual, err = subprocess.Popen([os.path.join('bin', 'asm2wasm'), os.path.join('test', asm)], stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
assert err == '', 'bad err:' + err
if not os.path.exists(os.path.join('test', wasm)):
print actual
raise Exception('output .wast file does not exist')
expected = open(os.path.join('test', wasm)).read()
if actual != expected:
Expand Down
190 changes: 189 additions & 1 deletion src/asm2wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,149 @@ static void abort_on(std::string why, IString element) {
abort();
}

// useful when we need to see our parent, in an expression stack
//
// Simple WebAssembly optimizer, improves common patterns we get in asm2wasm.
// Operates in-place.
//

struct WasmWalker {
wasm::Arena* allocator; // use an existing allocator, or null if no allocations

WasmWalker() : allocator(nullptr) {}
WasmWalker(wasm::Arena* allocator) : allocator(allocator) {}

// Each method receives an AST pointer, and it is replaced with what is returned.
virtual Expression* walkBlock(Block *curr) { return curr; };
virtual Expression* walkIf(If *curr) { return curr; };
virtual Expression* walkLoop(Loop *curr) { return curr; };
virtual Expression* walkLabel(Label *curr) { return curr; };
virtual Expression* walkBreak(Break *curr) { return curr; };
virtual Expression* walkSwitch(Switch *curr) { return curr; };
virtual Expression* walkCall(Call *curr) { return curr; };
virtual Expression* walkCallImport(CallImport *curr) { return curr; };
virtual Expression* walkCallIndirect(CallIndirect *curr) { return curr; };
virtual Expression* walkGetLocal(GetLocal *curr) { return curr; };
virtual Expression* walkSetLocal(SetLocal *curr) { return curr; };
virtual Expression* walkLoad(Load *curr) { return curr; };
virtual Expression* walkStore(Store *curr) { return curr; };
virtual Expression* walkConst(Const *curr) { return curr; };
virtual Expression* walkUnary(Unary *curr) { return curr; };
virtual Expression* walkBinary(Binary *curr) { return curr; };
virtual Expression* walkCompare(Compare *curr) { return curr; };
virtual Expression* walkConvert(Convert *curr) { return curr; };
virtual Expression* walkHost(Host *curr) { return curr; };

// children-first
Expression *walk(Expression *curr) {
if (!curr) return curr;

if (Block *cast = dynamic_cast<Block*>(curr)) {
ExpressionList& list = cast->list;
for (size_t z = 0; z < list.size(); z++) {
list[z] = walk(list[z]);
}
return walkBlock(cast);
}
if (If *cast = dynamic_cast<If*>(curr)) {
cast->condition = walk(cast->condition);
cast->ifTrue = walk(cast->ifTrue);
cast->ifFalse = walk(cast->ifFalse);
return walkIf(cast);
}
if (Loop *cast = dynamic_cast<Loop*>(curr)) {
cast->body = walk(cast->body);
return walkLoop(cast);
}
if (Label *cast = dynamic_cast<Label*>(curr)) {
return walkLabel(cast);
}
if (Break *cast = dynamic_cast<Break*>(curr)) {
cast->condition = walk(cast->condition);
cast->value = walk(cast->value);
return walkBreak(cast);
}
if (Switch *cast = dynamic_cast<Switch*>(curr)) {
cast->value = walk(cast->value);
for (auto& curr : cast->cases) {
curr.body = walk(curr.body);
}
cast->default_ = walk(cast->default_);
return walkSwitch(cast);
}
if (Call *cast = dynamic_cast<Call*>(curr)) {
ExpressionList& list = cast->operands;
for (size_t z = 0; z < list.size(); z++) {
list[z] = walk(list[z]);
}
return walkCall(cast);
}
if (CallImport *cast = dynamic_cast<CallImport*>(curr)) {
ExpressionList& list = cast->operands;
for (size_t z = 0; z < list.size(); z++) {
list[z] = walk(list[z]);
}
return walkCallImport(cast);
}
if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) {
cast->target = walk(cast->target);
ExpressionList& list = cast->operands;
for (size_t z = 0; z < list.size(); z++) {
list[z] = walk(list[z]);
}
return walkCallIndirect(cast);
}
if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) {
return walkGetLocal(cast);
}
if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) {
cast->value = walk(cast->value);
return walkSetLocal(cast);
}
if (Load *cast = dynamic_cast<Load*>(curr)) {
cast->ptr = walk(cast->ptr);
return walkLoad(cast);
}
if (Store *cast = dynamic_cast<Store*>(curr)) {
cast->ptr = walk(cast->ptr);
cast->value = walk(cast->value);
return walkStore(cast);
}
if (Const *cast = dynamic_cast<Const*>(curr)) {
return walkConst(cast);
}
if (Unary *cast = dynamic_cast<Unary*>(curr)) {
cast->value = walk(cast->value);
return walkUnary(cast);
}
if (Binary *cast = dynamic_cast<Binary*>(curr)) {
cast->left = walk(cast->left);
cast->right = walk(cast->right);
return walkBinary(cast);
}
if (Compare *cast = dynamic_cast<Compare*>(curr)) {
cast->left = walk(cast->left);
cast->right = walk(cast->right);
return walkCompare(cast);
}
if (Convert *cast = dynamic_cast<Convert*>(curr)) {
cast->value = walk(cast->value);
return walkConvert(cast);
}
if (Host *cast = dynamic_cast<Host*>(curr)) {
ExpressionList& list = cast->operands;
for (size_t z = 0; z < list.size(); z++) {
list[z] = walk(list[z]);
}
return walkHost(cast);
}
}

void startWalk(Function *func) {
func->body = walk(func->body);
}
};

// useful when we need to see our parent, in an asm.js expression stack
struct AstStackHelper {
static std::vector<Ref> astStack;
AstStackHelper(Ref curr) {
Expand Down Expand Up @@ -178,7 +320,9 @@ class Asm2WasmModule : public wasm::Module {

public:
Asm2WasmModule() : nextGlobal(8), maxGlobal(1000) {}

void processAsm(Ref ast);
void optimize();

private:
BasicType asmToWasmType(AsmType asmType) {
Expand Down Expand Up @@ -960,6 +1104,47 @@ Function* Asm2WasmModule::processFunction(Ref ast) {
return function;
}

void Asm2WasmModule::optimize() {
struct BlockRemover : public WasmWalker {
BlockRemover() : WasmWalker(nullptr) {}

Expression* walkBlock(Block *curr) override {
if (curr->list.size() != 1) return curr;
// just one element; maybe we can return just the element
if (curr->var.isNull()) return curr->list[0];
// we might be broken to, but if it's a trivial singleton child break, we can optimize here as well
Break *child = dynamic_cast<Break*>(curr->list[0]);
if (!child || child->var != curr->var || !child->value) return curr;

struct BreakSeeker : public WasmWalker {
IString target; // look for this one
size_t found;

BreakSeeker(IString target) : target(target), found(false) {}

Expression* walkBreak(Break *curr) override {
if (curr->var == target) found++;
}
};

// look in the child's children to see if there are more uses of this var
BreakSeeker breakSeeker(curr->var);
breakSeeker.walk(child->condition);
breakSeeker.walk(child->value);
if (breakSeeker.found == 0) return child->value;

return curr; // failed to optimize
}
};

BlockRemover blockRemover;
for (auto function : functions) {
blockRemover.startWalk(function);
}
}

// main

int main(int argc, char **argv) {
debug = !!getenv("ASM2WASM_DEBUG") && getenv("ASM2WASM_DEBUG")[0] != '0';

Expand Down Expand Up @@ -1007,6 +1192,9 @@ int main(int argc, char **argv) {
Asm2WasmModule wasm;
wasm.processAsm(asmjs);

if (debug) std::cerr << "optimizing...\n";
wasm.optimize();

if (debug) std::cerr << "printing...\n";
wasm.print(std::cout);

Expand Down
1 change: 1 addition & 0 deletions src/istring.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct IString {

IString() : str(nullptr) {}
IString(const char *s, bool reuse=true) { // if reuse=true, then input is assumed to remain alive; not copied
assert(s);
set(s, reuse);
}

Expand Down
Loading

0 comments on commit 665b5a7

Please sign in to comment.