Skip to content

Commit

Permalink
Fix encoder/decoder, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
magnified103 committed Nov 12, 2024
1 parent e5a0ad5 commit 1e1b9fc
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 60 deletions.
83 changes: 46 additions & 37 deletions packages/node/src/riblt/decoder.ts
Original file line number Diff line number Diff line change
@@ -1,62 +1,71 @@
import type { SourceSymbol, CodedSymbol, SourceSymbolFactory } from "./symbol.js";
import { type SourceSymbol, type CodedSymbol, type SourceSymbolFactory, HashedSymbol } from "./symbol.js";
import { RandomMapping } from "./mapping.js";
import { CodingPrefix } from "./encoder.js";


export class Decoder<T extends SourceSymbol> {
export class Decoder<T extends SourceSymbol> extends CodingPrefix<T> {
decodedSymbols: T[];
isDecoded: boolean[];
remaining: number;
pureSymbols: CodedSymbol<T>[];

constructor(private readonly sourceSymbolFactory: SourceSymbolFactory<T>) {
constructor(sourceSymbolFactory: SourceSymbolFactory<T>) {
super(sourceSymbolFactory);
this.decodedSymbols = [];
this.isDecoded = [];
this.remaining = 0;
this.pureSymbols = [];
}

tryDecode(local: CodedSymbol<T>[], remote: CodedSymbol<T>[]): boolean {
if (local.length !== remote.length) {
throw Error("The length of coded symbol sequences must be equal");
}

for (let i = 0; i < local.length; i++) {
local[i].xor(remote[i]);
extendPrefix(size: number): void {
super.extendPrefix(size);
while (this.isDecoded.length < size) {
this.isDecoded.push(false);
}
}

this.decodedSymbols = [];
const pureSymbols: CodedSymbol<T>[] = [];
let remaining = 0;
const isDecoded = new Array(local.length).fill(false);

for (let i = 0; i < local.length; i++) {
if (local[i].isZero()) {
isDecoded[i] = true;
} else {
remaining++;
if (local[i].isPure()) {
pureSymbols.push(local[i]);
}
// called at most once for each index
applyCodedSymbol(index: number, localSymbol: CodedSymbol<T>, remoteSymbol: CodedSymbol<T>): void {
this.extendPrefix(index + 1);
this.codedSymbols[index].apply(localSymbol, localSymbol.count);
this.codedSymbols[index].apply(remoteSymbol, -remoteSymbol.count);
if (this.codedSymbols[index].isZero()) {
this.isDecoded[index] = true;
} else {
this.remaining++;
if (this.codedSymbols[index].isPure()) {
this.pureSymbols.push(this.codedSymbols[index]);
}
}
}

while (pureSymbols.length > 0) {
const symbol = pureSymbols.pop() as CodedSymbol<T>;
tryDecode(): boolean {
while (this.pureSymbols.length > 0) {
const symbol = this.pureSymbols.pop() as CodedSymbol<T>;
// console.log(`pure symbol: ${symbol.sum.data} ${symbol.count}`);
if (symbol.isZero()) {
continue;
}
this.decodedSymbols.push(this.sourceSymbolFactory.clone(symbol.sum));
const decodedSymbol = this.sourceSymbolFactory.clone(symbol.sum)
this.decodedSymbols.push(decodedSymbol);

const mapping = new RandomMapping(symbol.checksum, 0);
while (mapping.lastIdx < local.length) {
while (mapping.lastIdx < this.codedSymbols.length) {
const idx = mapping.lastIdx;
if (isDecoded[idx]) {
continue;
}
local[idx].xor(symbol);
if (local[idx].isZero()) {
isDecoded[idx] = true;
remaining--;
} else if (local[idx].isPure()) {
pureSymbols.push(local[idx]);
if (!this.isDecoded[idx]) {
this.codedSymbols[idx].xor(symbol);
if (this.codedSymbols[idx].isZero()) {
this.isDecoded[idx] = true;
this.remaining--;
} else if (this.codedSymbols[idx].isPure()) {
this.pureSymbols.push(this.codedSymbols[idx]);
}
}
mapping.nextIndex();
}
this.addHashedSymbolWithMapping(new HashedSymbol<T>(decodedSymbol), mapping, -symbol.count);
}

return remaining === 0;
return this.remaining === 0;
}
}
27 changes: 13 additions & 14 deletions packages/node/src/riblt/encoder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,15 @@ class MappingHeap {

export class CodingPrefix<T extends SourceSymbol> {
private sourceSymbols: HashedSymbol<T>[];
private sourceSymbolDirections: number[];
public codedSymbols: CodedSymbol<T>[];
private mapGenerators: RandomMapping[];
private queue: MappingHeap;

constructor(private readonly sourceSymbolFactory: SourceSymbolFactory<T>) {
constructor(protected readonly sourceSymbolFactory: SourceSymbolFactory<T>) {
this.sourceSymbols = [];
this.codedSymbols = [new CodedSymbol<T>(sourceSymbolFactory.empty(), sourceSymbolFactory.emptyHash())];
this.sourceSymbolDirections = [];
this.codedSymbols = [new CodedSymbol<T>(sourceSymbolFactory.empty(), sourceSymbolFactory.emptyHash(), 0)];
this.mapGenerators = [];
this.queue = new MappingHeap();
}
Expand All @@ -94,27 +96,32 @@ export class CodingPrefix<T extends SourceSymbol> {
this.addHashedSymbol(hashedSymbol);
}

addHashedSymbol(hashedSymbol: HashedSymbol<T>): void {
addHashedSymbol(hashedSymbol: HashedSymbol<T>, direction = 1): void {
const mapping = new RandomMapping(hashedSymbol.checksum, 0);
this.addHashedSymbolWithMapping(hashedSymbol, mapping);
this.addHashedSymbolWithMapping(hashedSymbol, mapping, direction);
}

addHashedSymbolWithMapping(
protected addHashedSymbolWithMapping(
hashedSymbol: HashedSymbol<T>,
mapping: RandomMapping,
direction = 1,
): void {
this.sourceSymbols.push(hashedSymbol);
this.sourceSymbolDirections.push(direction);
this.mapGenerators.push(mapping);
this.queue.push(new SymbolMapping(this.sourceSymbols.length - 1, mapping.lastIdx));
}

extendPrefix(size: number): void {
while (this.codedSymbols.length < size) {
this.codedSymbols.push(new CodedSymbol<T>(this.sourceSymbolFactory.empty(), this.sourceSymbolFactory.emptyHash(), 0));
}
while (this.queue.size > 0 && this.queue.top.codedIdx < size) {
const mapping = this.queue.pop();
while (mapping.codedIdx < size) {
const sourceIdx = mapping.sourceIdx;
const codedIdx = mapping.codedIdx;
this.codedSymbols[codedIdx].apply(this.sourceSymbols[sourceIdx], 1);
this.codedSymbols[codedIdx].apply(this.sourceSymbols[sourceIdx], this.sourceSymbolDirections[sourceIdx]);
mapping.codedIdx = this.mapGenerators[sourceIdx].nextIndex();
}
this.queue.push(mapping);
Expand All @@ -123,14 +130,6 @@ export class CodingPrefix<T extends SourceSymbol> {
}

export class Encoder<T extends SourceSymbol> extends CodingPrefix<T> {
addSymbol(s: T): void {
super.addSymbol(s);
}

addHashedSymbol(s: HashedSymbol<T>): void {
super.addHashedSymbol(s);
}

producePrefix(size: number): CodedSymbol<T>[] {
super.extendPrefix(size);
return this.codedSymbols.slice(0, size);
Expand Down
23 changes: 17 additions & 6 deletions packages/node/src/riblt/mapping.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
import * as crypto from "node:crypto";


export class RandomMapping {
private prng: bigint;
private state: Uint8Array;
lastIdx: number;

constructor(prng: number, lastIdx = 0) {
this.prng = BigInt(prng);
constructor(seed: Uint8Array, lastIdx = 0) {
this.state = crypto.createHash("sha1").update(seed).digest();
this.lastIdx = lastIdx;
}

nextIndex(): number {
// xorshift128++

let prng = 0n;
prng |= BigInt(this.state[0]) << 0n;
prng |= BigInt(this.state[1]) << 8n;
prng |= BigInt(this.state[2]) << 16n;
prng |= BigInt(this.state[3]) << 24n;
prng |= BigInt(this.state[4]) << 32n;
prng |= BigInt(this.state[5]) << 40n;
prng |= BigInt(this.state[6]) << 48n;
prng |= BigInt(this.state[7]) << 56n;
this.lastIdx += Math.ceil(
(this.lastIdx + 1.5) * (2 ** 32 / Math.sqrt(Number(this.prng) + 1) - 1),
(this.lastIdx + 1.5) * (2 ** 32 / Math.sqrt(Number(prng) + 1) - 1),
);
this.state = crypto.createHash("sha1").update(this.state).digest();
return this.lastIdx;
}
}
10 changes: 7 additions & 3 deletions packages/node/src/riblt/symbol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ export class HashedSymbol<T extends SourceSymbol> {
sum: T;
checksum: Uint8Array;

constructor(sum: T, checksum: Uint8Array) {
constructor(sum: T, checksum?: Uint8Array) {
this.sum = sum;
this.checksum = checksum;
if (checksum === undefined) {
this.checksum = sum.hash();
} else {
this.checksum = checksum;
}
}

xor(s: HashedSymbol<T>): void {
Expand All @@ -44,7 +48,7 @@ export class HashedSymbol<T extends SourceSymbol> {
export class CodedSymbol<T extends SourceSymbol> extends HashedSymbol<T> {
count: number;

constructor(sum: T, checksum: Uint8Array, count = 1) {
constructor(sum: T, checksum: Uint8Array, count: number) {
super(sum, checksum);
this.count = count;
}
Expand Down
86 changes: 86 additions & 0 deletions packages/node/tests/riblt.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import { beforeEach, describe, expect, test } from "vitest";
import { Encoder } from "../src/riblt/encoder.js";
import { Decoder } from "../src/riblt/decoder.js";
import type { SourceSymbolFactory, SourceSymbol } from "../src/riblt/symbol.js";
import * as crypto from 'node:crypto';


class VertexSymbol implements SourceSymbol {
data: Uint8Array;

constructor(data: Uint8Array) {
this.data = data;
}

xor(s: VertexSymbol): void {
for (let i = 0; i < this.data.length; i++) {
this.data[i] ^= s.data[i];
}
}

hash(): Uint8Array {
return crypto.createHash('sha1').update(this.data).digest();
}

equals(s: VertexSymbol): boolean {
for (let i = 0; i < this.data.length; i++) {
if (this.data[i] !== s.data[i]) {
return false;
}
}
return true;
}
}


class VertexSymbolFactory implements SourceSymbolFactory<VertexSymbol> {
empty(): VertexSymbol {
return new VertexSymbol(new Uint8Array(32));
}

emptyHash(): Uint8Array {
return new Uint8Array(20);
}

clone(s: VertexSymbol): VertexSymbol {
return new VertexSymbol(new Uint8Array(s.data));
}
}


describe("RIBLT test", async () => {
const factory = new VertexSymbolFactory();
const v0 = factory.empty();
const v1 = factory.empty();
const v2 = factory.empty();

v0.data[0] = 1;
v1.data[0] = 2;
v2.data[0] = 4;

const aliceEncoder = new Encoder(factory);
const bobEncoder = new Encoder(factory);

aliceEncoder.addSymbol(v0);
aliceEncoder.addSymbol(v2);

bobEncoder.addSymbol(v1);
bobEncoder.addSymbol(v2);

aliceEncoder.extendPrefix(10);
bobEncoder.extendPrefix(10);

const bobDecoder = new Decoder(factory);

for (let i = 0; i < 10; i++) {
// console.log(`${i}: ${aliceEncoder.codedSymbols[i].sum.data} ${aliceEncoder.codedSymbols[i].count} ${bobEncoder.codedSymbols[i].sum.data} ${bobEncoder.codedSymbols[i].count}`);
bobDecoder.applyCodedSymbol(i, aliceEncoder.codedSymbols[i], bobEncoder.codedSymbols[i]);
}

// for (let i = 0; i < 10; i++) {
// console.log(`Decoded: ${bobDecoder.codedSymbols[i].sum.data} ${bobDecoder.codedSymbols[i].count}`);
// }

expect(bobDecoder.tryDecode()).toBe(true);
console.log(bobDecoder.decodedSymbols);
});

0 comments on commit 1e1b9fc

Please sign in to comment.