Skip to content

Commit

Permalink
Merge pull request #43689 from heshanpadmasiri/fix/bdd-optim
Browse files Browse the repository at this point in the history
Optimize BDD operations
  • Loading branch information
chiranSachintha authored Dec 9, 2024
2 parents 9b642cc + 5e99d27 commit 0c01583
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pull_request_full_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
strategy:
fail-fast: false
matrix:
level: [ 1, 2, 3, 4, 5, 6, 7, 8, 9 ]
level: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ]

steps:
- name: Checkout Repository
Expand Down
112 changes: 87 additions & 25 deletions semtypes/src/main/java/io/ballerina/types/typeops/BddCommonOps.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import io.ballerina.types.subtypedata.BddAllOrNothing;
import io.ballerina.types.subtypedata.BddNode;

import java.util.HashMap;
import java.util.Map;

/**
* Contain common BDD operations found in bdd.bal file.
*
Expand All @@ -38,6 +41,21 @@ public static BddNode bddAtom(Atom atom) {
}

public static Bdd bddUnion(Bdd b1, Bdd b2) {
return bddUnionWithMemo(BddOpMemo.create(), b1, b2);
}

private static Bdd bddUnionWithMemo(BddOpMemo memoTable, Bdd b1, Bdd b2) {
BddOpMemoKey key = new BddOpMemoKey(b1, b2);
Bdd memoized = memoTable.unionMemo.get(key);
if (memoized != null) {
return memoized;
}
memoized = bddUnionInner(memoTable, b1, b2);
memoTable.unionMemo.put(key, memoized);
return memoized;
}

private static Bdd bddUnionInner(BddOpMemo memo, Bdd b1, Bdd b2) {
if (b1 == b2) {
return b1;
} else if (b1 instanceof BddAllOrNothing) {
Expand All @@ -51,23 +69,38 @@ public static Bdd bddUnion(Bdd b1, Bdd b2) {
if (cmp < 0L) {
return bddCreate(b1Bdd.atom(),
b1Bdd.left(),
bddUnion(b1Bdd.middle(), b2),
bddUnionWithMemo(memo, b1Bdd.middle(), b2),
b1Bdd.right());
} else if (cmp > 0L) {
return bddCreate(b2Bdd.atom(),
b2Bdd.left(),
bddUnion(b1, b2Bdd.middle()),
bddUnionWithMemo(memo, b1, b2Bdd.middle()),
b2Bdd.right());
} else {
return bddCreate(b1Bdd.atom(),
bddUnion(b1Bdd.left(), b2Bdd.left()),
bddUnion(b1Bdd.middle(), b2Bdd.middle()),
bddUnion(b1Bdd.right(), b2Bdd.right()));
bddUnionWithMemo(memo, b1Bdd.left(), b2Bdd.left()),
bddUnionWithMemo(memo, b1Bdd.middle(), b2Bdd.middle()),
bddUnionWithMemo(memo, b1Bdd.right(), b2Bdd.right()));
}
}
}

public static Bdd bddIntersect(Bdd b1, Bdd b2) {
return bddIntersectWithMemo(BddOpMemo.create(), b1, b2);
}

private static Bdd bddIntersectWithMemo(BddOpMemo memo, Bdd b1, Bdd b2) {
BddOpMemoKey key = new BddOpMemoKey(b1, b2);
Bdd memoized = memo.intersectionMemo.get(key);
if (memoized != null) {
return memoized;
}
memoized = bddIntersectInner(memo, b1, b2);
memo.intersectionMemo.put(key, memoized);
return memoized;
}

private static Bdd bddIntersectInner(BddOpMemo memo, Bdd b1, Bdd b2) {
if (b1 == b2) {
return b1;
} else if (b1 instanceof BddAllOrNothing) {
Expand All @@ -80,28 +113,43 @@ public static Bdd bddIntersect(Bdd b1, Bdd b2) {
long cmp = atomCmp(b1Bdd.atom(), b2Bdd.atom());
if (cmp < 0L) {
return bddCreate(b1Bdd.atom(),
bddIntersect(b1Bdd.left(), b2),
bddIntersect(b1Bdd.middle(), b2),
bddIntersect(b1Bdd.right(), b2));
bddIntersectWithMemo(memo, b1Bdd.left(), b2),
bddIntersectWithMemo(memo, b1Bdd.middle(), b2),
bddIntersectWithMemo(memo, b1Bdd.right(), b2));
} else if (cmp > 0L) {
return bddCreate(b2Bdd.atom(),
bddIntersect(b1, b2Bdd.left()),
bddIntersect(b1, b2Bdd.middle()),
bddIntersect(b1, b2Bdd.right()));
bddIntersectWithMemo(memo, b1, b2Bdd.left()),
bddIntersectWithMemo(memo, b1, b2Bdd.middle()),
bddIntersectWithMemo(memo, b1, b2Bdd.right()));
} else {
return bddCreate(b1Bdd.atom(),
bddIntersect(
bddUnion(b1Bdd.left(), b1Bdd.middle()),
bddUnion(b2Bdd.left(), b2Bdd.middle())),
bddIntersectWithMemo(memo,
bddUnionWithMemo(memo, b1Bdd.left(), b1Bdd.middle()),
bddUnionWithMemo(memo, b2Bdd.left(), b2Bdd.middle())),
BddAllOrNothing.bddNothing(),
bddIntersect(
bddUnion(b1Bdd.right(), b1Bdd.middle()),
bddUnion(b2Bdd.right(), b2Bdd.middle())));
bddIntersectWithMemo(memo,
bddUnionWithMemo(memo, b1Bdd.right(), b1Bdd.middle()),
bddUnionWithMemo(memo, b2Bdd.right(), b2Bdd.middle())));
}
}
}

public static Bdd bddDiff(Bdd b1, Bdd b2) {
return bddDiffWithMemo(BddOpMemo.create(), b1, b2);
}

private static Bdd bddDiffWithMemo(BddOpMemo memo, Bdd b1, Bdd b2) {
BddOpMemoKey key = new BddOpMemoKey(b1, b2);
Bdd memoized = memo.diffMemo.get(key);
if (memoized != null) {
return memoized;
}
memoized = bddDiffInner(memo, b1, b2);
memo.diffMemo.put(key, memoized);
return memoized;
}

private static Bdd bddDiffInner(BddOpMemo memo, Bdd b1, Bdd b2) {
if (b1 == b2) {
return BddAllOrNothing.bddNothing();
} else if (b2 instanceof BddAllOrNothing allOrNothing) {
Expand All @@ -114,25 +162,27 @@ public static Bdd bddDiff(Bdd b1, Bdd b2) {
long cmp = atomCmp(b1Bdd.atom(), b2Bdd.atom());
if (cmp < 0L) {
return bddCreate(b1Bdd.atom(),
bddDiff(bddUnion(b1Bdd.left(), b1Bdd.middle()), b2),
bddDiffWithMemo(memo, bddUnionWithMemo(memo, b1Bdd.left(), b1Bdd.middle()), b2),
BddAllOrNothing.bddNothing(),
bddDiff(bddUnion(b1Bdd.right(), b1Bdd.middle()), b2));
bddDiffWithMemo(memo, bddUnionWithMemo(memo, b1Bdd.right(), b1Bdd.middle()), b2));
} else if (cmp > 0L) {
return bddCreate(b2Bdd.atom(),
bddDiff(b1, bddUnion(b2Bdd.left(), b2Bdd.middle())),
bddDiffWithMemo(memo, b1, bddUnionWithMemo(memo, b2Bdd.left(), b2Bdd.middle())),
BddAllOrNothing.bddNothing(),
bddDiff(b1, bddUnion(b2Bdd.right(), b2Bdd.middle())));
bddDiffWithMemo(memo, b1, bddUnionWithMemo(memo, b2Bdd.right(), b2Bdd.middle())));
} else {
// There is an error in the Castagna paper for this formula.
// The union needs to be materialized here.
// The original formula does not work in a case like (a0|a1) - a0.
// Castagna confirms that the following formula is the correct one.
return bddCreate(b1Bdd.atom(),
bddDiff(bddUnion(b1Bdd.left(), b1Bdd.middle()),
bddUnion(b2Bdd.left(), b2Bdd.middle())),
bddDiffWithMemo(memo,
bddUnionWithMemo(memo, b1Bdd.left(), b1Bdd.middle()),
bddUnionWithMemo(memo, b2Bdd.left(), b2Bdd.middle())),
BddAllOrNothing.bddNothing(),
bddDiff(bddUnion(b1Bdd.right(), b1Bdd.middle()),
bddUnion(b2Bdd.right(), b2Bdd.middle())));
bddDiffWithMemo(memo,
bddUnionWithMemo(memo, b1Bdd.right(), b1Bdd.middle()),
bddUnionWithMemo(memo, b2Bdd.right(), b2Bdd.middle())));
}
}
}
Expand Down Expand Up @@ -222,4 +272,16 @@ public static String bddToString(Bdd b, boolean inner) {
return str;
}
}

private record BddOpMemoKey(Bdd b1, Bdd b2) {

}

private record BddOpMemo(Map<BddOpMemoKey, Bdd> unionMemo, Map<BddOpMemoKey, Bdd> intersectionMemo,
Map<BddOpMemoKey, Bdd> diffMemo) {

static BddOpMemo create() {
return new BddOpMemo(new HashMap<>(), new HashMap<>(), new HashMap<>());
}
}
}

0 comments on commit 0c01583

Please sign in to comment.